Implement dict update operator (|) for savers.

This commit is contained in:
Owllex 2022-05-22 00:22:12 -07:00
parent 0c62046224
commit 97c0d7eb30
2 changed files with 33 additions and 17 deletions

View file

@ -239,6 +239,9 @@ class _SaverMutable(object):
def __gt__(self, other):
return self._data > other
def __or__(self, other):
return self._data | other
@_save
def __setitem__(self, key, value):
self._data.__setitem__(key, self._convert_mutables(value))
@ -450,7 +453,9 @@ def deserialize(obj):
elif tname in ("_SaverOrderedDict", "OrderedDict"):
return OrderedDict([(_iter(key), _iter(val)) for key, val in obj.items()])
elif tname in ("_SaverDefaultDict", "defaultdict"):
return defaultdict(obj.default_factory, {_iter(key): _iter(val) for key, val in obj.items()})
return defaultdict(
obj.default_factory, {_iter(key): _iter(val) for key, val in obj.items()}
)
elif tname in _DESERIALIZE_MAPPING:
return _DESERIALIZE_MAPPING[tname](_iter(val) for val in obj)
elif is_iter(obj):
@ -612,7 +617,10 @@ def to_pickle(data):
elif dtype in (dict, _SaverDict):
return dict((process_item(key), process_item(val)) for key, val in item.items())
elif dtype in (defaultdict, _SaverDefaultDict):
return defaultdict(item.default_factory, ((process_item(key), process_item(val)) for key, val in item.items()))
return defaultdict(
item.default_factory,
((process_item(key), process_item(val)) for key, val in item.items()),
)
elif dtype in (set, _SaverSet):
return set(process_item(val) for val in item)
elif dtype in (OrderedDict, _SaverOrderedDict):
@ -678,7 +686,10 @@ def from_pickle(data, db_obj=None):
elif dtype == dict:
return dict((process_item(key), process_item(val)) for key, val in item.items())
elif dtype == defaultdict:
return defaultdict(item.default_factory, ((process_item(key), process_item(val)) for key, val in item.items()))
return defaultdict(
item.default_factory,
((process_item(key), process_item(val)) for key, val in item.items()),
)
elif dtype == set:
return set(process_item(val) for val in item)
elif dtype == OrderedDict:

View file

@ -62,10 +62,12 @@ class TestDbSerialize(TestCase):
self.obj.db.test.sort(key=lambda d: str(d))
self.assertEqual(self.obj.db.test, [{0: 1}, {1: 0}])
def test_dict(self):
def test_saverdict(self):
self.obj.db.test = {"a": True}
self.obj.db.test.update({"b": False})
self.assertEqual(self.obj.db.test, {"a": True, "b": False})
self.obj.db.test |= {"c": 5}
self.assertEqual(self.obj.db.test, {"a": True, "b": False, "c": 5})
@parameterized.expand(
[
@ -88,27 +90,30 @@ class TestDbSerialize(TestCase):
self.assertIsInstance(value, base_type)
self.assertNotIsInstance(value, saver_type)
self.assertEqual(value, default_value)
self.obj.db.test = {'a': True}
self.obj.db.test.update({'b': False})
self.assertEqual(self.obj.db.test, {'a': True, 'b': False})
self.obj.db.test = {"a": True}
self.obj.db.test.update({"b": False})
self.assertEqual(self.obj.db.test, {"a": True, "b": False})
def test_defaultdict(self):
from collections import defaultdict
# baseline behavior for a defaultdict
_dd = defaultdict(list)
_dd['a']
self.assertEqual(_dd, {'a': []})
_dd["a"]
self.assertEqual(_dd, {"a": []})
# behavior after defaultdict is set as attribute
dd = defaultdict(list)
self.obj.db.test = dd
self.obj.db.test['a']
self.assertEqual(self.obj.db.test, {'a': []})
self.obj.db.test["a"]
self.assertEqual(self.obj.db.test, {"a": []})
self.obj.db.test['a'].append(1)
self.assertEqual(self.obj.db.test, {'a': [1]})
self.obj.db.test['a'].append(2)
self.assertEqual(self.obj.db.test, {'a': [1, 2]})
self.obj.db.test['a'].append(3)
self.assertEqual(self.obj.db.test, {'a': [1, 2, 3]})
self.obj.db.test["a"].append(1)
self.assertEqual(self.obj.db.test, {"a": [1]})
self.obj.db.test["a"].append(2)
self.assertEqual(self.obj.db.test, {"a": [1, 2]})
self.obj.db.test["a"].append(3)
self.assertEqual(self.obj.db.test, {"a": [1, 2, 3]})
self.obj.db.test |= {"b": [5, 6]}
self.assertEqual(self.obj.db.test, {"a": [1, 2, 3], "b": [5, 6]})