diff --git a/CHANGELOG.md b/CHANGELOG.md index e288ab3265..ec512fb621 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -157,6 +157,7 @@ Up requirements to Django 4.0+, Twisted 22+, Python 3.9 or 3.10 except more standard aliases logger.error/info/exception/debug etc can now be used. - Have `type/force` default to `update`-mode rather than `reset`mode and add more verbose warning when using reset mode. +- Attribute storage support defaultdics (Hendher) ## Evennia 0.9.5 diff --git a/evennia/utils/dbserialize.py b/evennia/utils/dbserialize.py index 420f124197..ce6743bc5b 100644 --- a/evennia/utils/dbserialize.py +++ b/evennia/utils/dbserialize.py @@ -203,6 +203,10 @@ class _SaverMutable(object): dat = _SaverDict(_parent=parent) dat._data.update((key, process_tree(val, dat)) for key, val in item.items()) return dat + elif dtype == defaultdict: + dat = _SaverDefaultDict(item.default_factory, _parent=parent) + dat._data.update((key, process_tree(val, dat)) for key, val in item.items()) + return dat elif dtype == set: dat = _SaverSet(_parent=parent) dat._data.update(process_tree(val, dat) for val in item) @@ -309,6 +313,25 @@ class _SaverDict(_SaverMutable, MutableMapping): self._data.update(*args, **kwargs) +class _SaverDefaultDict(_SaverDict): + """ + A defaultdict that stores changes to an attribute when updated + """ + + def __init__(self, factory, *args, **kwargs): + super().__init__(*args, **kwargs) + self._data = defaultdict(factory) + self.default_factory = factory + + def __getitem__(self, key): + if key not in self._data.keys(): + # detect the case of db.foo['a'] with no immediate assignment + # (important: using `key in self._data` would be always True!) + default_value = self._data[key] + self.__setitem__(key, default_value) + return self._data[key] + + class _SaverSet(_SaverMutable, MutableSet): """ A set that saves to an Attribute when updated @@ -407,6 +430,7 @@ _DESERIALIZE_MAPPING = { _SaverSet.__name__: set, _SaverOrderedDict.__name__: OrderedDict, _SaverDeque.__name__: deque, + _SaverDefaultDict.__name__: defaultdict, } @@ -418,12 +442,15 @@ def deserialize(obj): """ def _iter(obj): + # breakpoint() typ = type(obj) tname = typ.__name__ if tname in ("_SaverDict", "dict"): return {_iter(key): _iter(val) for key, val in obj.items()} elif tname in ("_SaverOrderedDict", "OrderedDict"): return OrderedDict([(_iter(key), _iter(val)) for key, val in obj.items()]) + elif tname in ("_SaverDefaultDict", "defaultdict"): + return defaultdict(obj.default_factory, {_iter(key): _iter(val) for key, val in obj.items()}) elif tname in _DESERIALIZE_MAPPING: return _DESERIALIZE_MAPPING[tname](_iter(val) for val in obj) elif is_iter(obj): @@ -584,6 +611,8 @@ def to_pickle(data): return [process_item(val) for val in item] elif dtype in (dict, _SaverDict): return dict((process_item(key), process_item(val)) for key, val in item.items()) + elif dtype in (defaultdict, _SaverDefaultDict): + return defaultdict(item.default_factory, ((process_item(key), process_item(val)) for key, val in item.items())) elif dtype in (set, _SaverSet): return set(process_item(val) for val in item) elif dtype in (OrderedDict, _SaverOrderedDict): @@ -635,6 +664,7 @@ def from_pickle(data, db_obj=None): def process_item(item): """Recursive processor and identification of data""" + # breakpoint() dtype = type(item) if dtype in (str, int, float, bool, bytes, SafeString): return item @@ -647,6 +677,8 @@ def from_pickle(data, db_obj=None): return tuple(process_item(val) for val in item) elif dtype == dict: return dict((process_item(key), process_item(val)) for key, val in item.items()) + elif dtype == defaultdict: + return defaultdict(item.default_factory, ((process_item(key), process_item(val)) for key, val in item.items())) elif dtype == set: return set(process_item(val) for val in item) elif dtype == OrderedDict: @@ -664,6 +696,7 @@ def from_pickle(data, db_obj=None): def process_tree(item, parent): """Recursive processor, building a parent-tree from iterable data""" + # breakpoint() dtype = type(item) if dtype in (str, int, float, bool, bytes, SafeString): return item @@ -682,6 +715,12 @@ def from_pickle(data, db_obj=None): (process_item(key), process_tree(val, dat)) for key, val in item.items() ) return dat + elif dtype == defaultdict: + dat = _SaverDefaultDict(item.default_factory, _parent=parent) + dat._data.update( + (process_item(key), process_tree(val, dat)) for key, val in item.items() + ) + return dat elif dtype == set: dat = _SaverSet(_parent=parent) dat._data.update(set(process_tree(val, dat) for val in item)) @@ -721,6 +760,12 @@ def from_pickle(data, db_obj=None): (process_item(key), process_tree(val, dat)) for key, val in data.items() ) return dat + elif dtype == defaultdict: + dat = _SaverDefaultDict(data.default_factory, _db_obj=db_obj) + dat._data.update( + (process_item(key), process_tree(val, dat)) for key, val in data.items() + ) + return dat elif dtype == set: dat = _SaverSet(_db_obj=db_obj) dat._data.update(process_tree(val, dat) for val in data) diff --git a/evennia/utils/tests/test_dbserialize.py b/evennia/utils/tests/test_dbserialize.py index c0e23b928a..028d6d1f72 100644 --- a/evennia/utils/tests/test_dbserialize.py +++ b/evennia/utils/tests/test_dbserialize.py @@ -88,3 +88,27 @@ class TestDbSerialize(TestCase): self.assertIsInstance(value, base_type) self.assertNotIsInstance(value, saver_type) self.assertEqual(value, default_value) + self.obj.db.test = {'a': True} + self.obj.db.test.update({'b': False}) + self.assertEqual(self.obj.db.test, {'a': True, 'b': False}) + + def test_defaultdict(self): + from collections import defaultdict + # baseline behavior for a defaultdict + _dd = defaultdict(list) + _dd['a'] + self.assertEqual(_dd, {'a': []}) + + # behavior after defaultdict is set as attribute + + dd = defaultdict(list) + self.obj.db.test = dd + self.obj.db.test['a'] + self.assertEqual(self.obj.db.test, {'a': []}) + + self.obj.db.test['a'].append(1) + self.assertEqual(self.obj.db.test, {'a': [1]}) + self.obj.db.test['a'].append(2) + self.assertEqual(self.obj.db.test, {'a': [1, 2]}) + self.obj.db.test['a'].append(3) + self.assertEqual(self.obj.db.test, {'a': [1, 2, 3]})