From c41fd0a33bcc1e87a050a1411766ae48d41345b9 Mon Sep 17 00:00:00 2001 From: Griatch Date: Sat, 18 Apr 2020 14:43:10 +0200 Subject: [PATCH] Add unit tests for base traits --- evennia/contrib/test_traits.py | 342 +++++++++++++++++++++++++++++++-- evennia/contrib/traits.py | 113 ++++++++--- 2 files changed, 407 insertions(+), 48 deletions(-) diff --git a/evennia/contrib/test_traits.py b/evennia/contrib/test_traits.py index 662415d216..6c7bcd1b8f 100644 --- a/evennia/contrib/test_traits.py +++ b/evennia/contrib/test_traits.py @@ -6,13 +6,13 @@ Unit test module for Trait classes. """ -from mock import MagicMock +from copy import copy +from anything import Something +from mock import MagicMock, patch from django.test import TestCase from django.test import override_settings from evennia.utils.test_resources import EvenniaTest -from evennia.contrib.traits import ( - TraitHandler, Trait, NumericTrait, - StaticTrait, CounterTrait, GaugeTrait) +from evennia.contrib import traits class _MockObj: @@ -31,35 +31,337 @@ class _MockObj: assert category == self.category self.dbstore[key] = value +# we want to test the base traits too +_TEST_TRAIT_CLASS_PATHS = [ + "evennia.contrib.traits.Trait", + "evennia.contrib.traits.NumericTrait", + "evennia.contrib.traits.StaticTrait", + "evennia.contrib.traits.CounterTrait", + "evennia.contrib.traits.GaugeTrait", +] -@override_settings(TRAIT_CLASS_PATHS=["evennia.contrib.traits.Trait"]) -class TraitHandlerTest(TestCase): - """Test case for TraitHandler""" - +class _TraitHandlerBase(TestCase): + "Base for trait tests" + @patch("evennia.contrib.traits._TRAIT_CLASS_PATHS", new=_TEST_TRAIT_CLASS_PATHS) def setUp(self): self.obj = _MockObj() - self.traithandler = TraitHandler(self.obj) + self.traithandler = traits.TraitHandler(self.obj) + self.obj.traits = self.traithandler + + def _get_dbstore(self, key): + return self.obj.dbstore['traits'][key] + + +class TraitHandlerTest(_TraitHandlerBase): + """Testing for TraitHandler""" + + def setUp(self): + super().setUp() + self.traithandler.add( + "test1", + name="Test1", + trait_type='trait' + ) + self.traithandler.add( + "test2", + name="Test2", + trait_type='trait', + value=["foo", {"1": [1, 2, 3]}, 4], + ) def test_add_trait(self): - - self.traithandler.add( - "test", - name="Test", - trait_type="trait" - ) - self.assertEqual( - self.obj.dbstore["test"], - {"name": "Test", - "trait_type": "trait"} + self._get_dbstore("test1"), + {"name": "Test1", + "trait_type": 'trait', + "value": None, + } + ) + self.assertEqual( + self._get_dbstore("test2"), + {"name": "Test2", + "trait_type": 'trait', + "value": ["foo", {"1": [1, 2, 3]}, 4], + } + ) + self.assertEqual(len(self.traithandler), 2) + + def test_cache(self): + """ + Cache should not be set until first get + """ + self.assertEqual(len(self.traithandler._cache), 0) + self.traithandler.all # does not affect cache + self.assertEqual(len(self.traithandler._cache), 0) + self.traithandler.test1 + self.assertEqual(len(self.traithandler._cache), 1) + self.traithandler.test2 + self.assertEqual(len(self.traithandler._cache), 2) + + def test_setting(self): + "Don't allow setting stuff on traithandler" + with self.assertRaises(traits.TraitException): + self.traithandler.foo = "bar" + with self.assertRaises(traits.TraitException): + self.traithandler["foo"] = "bar" + + def test_getting(self): + "Test we are getting data from the dbstore" + self.assertEqual( + self.traithandler.test1._data, + {"name": "Test1", "trait_type": "trait", + "value": None} + ) + self.assertEqual( + self.traithandler._cache, Something + ) + self.assertEqual( + self.traithandler.test2._data, + {"name": "Test2", "trait_type": "trait", + "value": ["foo", {"1": [1, 2, 3]}, 4]} + ) + self.assertEqual( + self.traithandler._cache, Something + ) + self.assertFalse(self.traithandler.get("foo")) + self.assertFalse(self.traithandler.bar) + + def test_all(self): + "Test all method" + self.assertEqual(self.traithandler.all, ["test1", "test2"]) + + def test_remove(self): + "Test remove method" + self.traithandler.remove("test2") + self.assertEqual(len(self.traithandler), 1) + self.assertTrue(bool(self.traithandler.get("test1"))) # this populates cache + self.assertEqual(len(self.traithandler._cache), 1) + with self.assertRaises(traits.TraitException): + self.traithandler.remove("foo") + + def test_clear(self): + "Test clear method" + self.traithandler.clear() + self.assertEqual(len(self.traithandler), 0) + + def test_trait_db_connection(self): + "Test that updating a trait property actually updates value in db" + trait = self.traithandler.test1 + self.assertEqual(trait.value, None) + trait.value = 10 + self.assertEqual(trait.value, 10) + self.assertEqual( + self.obj.attributes.get("traits", category="traits")['test1']['value'], + 10 + ) + trait.value = 20 + self.assertEqual(trait.value, 20) + self.assertEqual( + self.obj.attributes.get("traits", category="traits")['test1']['value'], + 20 + ) + del trait.value + self.assertEqual( + self.obj.attributes.get("traits", category="traits")['test1']['value'], + None + ) + + +class TraitTest(_TraitHandlerBase): + """ + Test the base Trait class + """ + + def setUp(self): + super().setUp() + self.traithandler.add( + "test1", + name="Test1", + trait_type="trait", + value="value", + extra_val1="xvalue1", + extra_val2="xvalue2", + ) + self.trait = self.traithandler.get("test1") + + def test_init(self): + self.assertEqual( + self.trait._data, + {"name": "Test1", + "trait_type": "trait", + "value": "value", + "extra_val1": "xvalue1", + "extra_val2": "xvalue2" + } + ) + + def test_validate_input__valid(self): + """Test valid validation input""" + # all data supplied, and extras + dat = { + "name": "Test", + "trait_type": "trait", + "value": 10, + "extra_val": 1000 + } + expected = copy(dat) # we must break link or return === dat always + self.assertEqual(expected, traits.Trait.validate_input(dat)) + + # don't supply value, should get default + dat = { + "name": "Test", + "trait_type": "trait", + # missing value + "extra_val": 1000 + } + expected = copy(dat) + expected["value"] = traits.Trait.data_keys['value'] + self.assertEqual(expected, traits.Trait.validate_input(dat)) + + # make sure extra values are cleaned if trait accepts no extras + dat = { + "name": "Test", + "trait_type": "trait", + "value": 10, + "extra_val1": 1000, + "extra_val2": "xvalue" + } + expected = copy(dat) + expected.pop("extra_val1") + expected.pop("extra_val2") + with patch.object(traits.Trait, "allow_extra_properties", False): + self.assertEqual(expected, traits.Trait.validate_input(dat)) + + def test_validate_input__fail(self): + """Test failing validation""" + dat = { + # missing name + "trait_type": "trait", + "value": 10, + "extra_val": 1000 + } + with self.assertRaises(traits.TraitException): + traits.Trait.validate_input(dat) + + # make value a required key + mock_data_keys = { + "value": traits.MandatoryTraitKey + } + with patch.object(traits.Trait, "data_keys", mock_data_keys): + dat = { + "name": "Trait", + "trait_type": "trait", + # missing value, now mandatory + "extra_val": 1000 + } + with self.assertRaises(traits.TraitException): + traits.Trait.validate_input(dat) + + def test_trait_getset(self): + """Get-set-del operations on trait""" + self.assertEqual(self.trait.name, "Test1") + self.assertEqual(self.trait['name'], "Test1") + self.assertEqual(self.trait.value, "value") + self.assertEqual(self.trait['value'], "value") + self.assertEqual(self.trait.extra_val1, "xvalue1" ) + self.assertEqual(self.trait['extra_val2'], "xvalue2") + + self.trait.value = 20 + self.assertEqual(self.trait['value'], 20) + self.trait['value'] = 20 + self.assertEqual(self.trait.value, 20) + self.trait.extra_val1 = 100 + self.assertEqual(self.trait.extra_val1, 100) + # additional properties + self.trait.foo = "bar" + self.assertEqual(self.trait.foo, "bar") + + del self.trait.foo + with self.assertRaises(KeyError): + self.trait['foo'] + with self.assertRaises(AttributeError): + self.trait.foo + del self.trait.extra_val1 + with self.assertRaises(AttributeError): + self.trait.extra_val1 + del self.trait.value + # fall back to default + self.assertTrue(self.trait.value == traits.Trait.data_keys["value"]) + + def test_repr(self): + self.assertEqual(repr(self.trait), Something) + self.assertEqual(str(self.trait), Something) + + +class TestTraitNumeric(_TraitHandlerBase): + + def test_trait__numeric(self): + self.traithandler.add( + "test2", + name="Test2", + trait_type='numeric', + ) + self.assertEqual( + self._get_dbstore("test2"), + {"name": "Test2", + "trait_type": 'numeric', + "base": 0, + } ) + def test_trait__static(self): + self.traithandler.add( + "test3", + name="Test3", + trait_type='static' + ) + self.assertEqual( + self._get_dbstore("test3"), + {"name": "Test3", + "trait_type": 'static', + "base": 0, + "mod": 0, + } + ) + def test_trait__counter(self): + self.traithandler.add( + "test4", + name="Test4", + trait_type='counter' + ) + self.assertEqual( + self._get_dbstore("test4"), + {"name": "Test4", + "trait_type": 'counter', + "base": 0, + "mod": 0, + "current": 0, + "max_value": None, + "min_value": None, + } + ) - + def test_trait__gauge(self): + self.traithandler.add( + "test5", + name="Test5", + trait_type='gauge' + ) + self.assertEqual( + self._get_dbstore("test5"), + {"name": "Test5", + "trait_type": 'gauge', + "base": 0, + "mod": 0, + "current": 0, + "max_value": None, + "min_value": None, + } + ) # # diff --git a/evennia/contrib/traits.py b/evennia/contrib/traits.py index 0f4ea96060..118b051c37 100644 --- a/evennia/contrib/traits.py +++ b/evennia/contrib/traits.py @@ -241,7 +241,7 @@ from django.conf import settings from functools import total_ordering from evennia.utils.dbserialize import _SaverDict from evennia.utils import logger -from evennia.utils.utils import inherits_from, class_from_module +from evennia.utils.utils import inherits_from, class_from_module, list_to_string # This way the user can easily supply their own. Each @@ -292,7 +292,7 @@ _SA = object.__setattr__ DEFAULT_TRAIT_TYPE = "static" -class TraitException(Exception): +class TraitException(RuntimeError): """ Base exception class raised by `Trait` objects. @@ -331,13 +331,15 @@ class TraitHandler: # load the available classes, if necessary _delayed_import_trait_classes() - # Note that this retains the connection to the database, meaning every + # Note that .trait_data retains the connection to the database, meaning every # update we do to .trait_data automatically syncs with database. self.trait_data = obj.attributes.get(db_attribute_key, category=db_attribute_category) if self.trait_data is None: - # no existing storage; initialize it + # no existing storage; initialize it, we then have to fetch it again + # to retain the db connection obj.attributes.add(db_attribute_key, {}, category=db_attribute_category) - self.trait_data = {} + self.trait_data = obj.attributes.get( + db_attribute_key, category=db_attribute_category) self._cache = {} def __len__(self): @@ -385,8 +387,7 @@ class TraitHandler: def get(self, key): """ Args: - trait (str): key from the traits dict containing config data - for the trait. "all" returns a list of all trait keys. + key (str): key from the traits dict containing config data. Returns: (`Trait` or `None`): named Trait class or None if trait key @@ -435,7 +436,7 @@ class TraitHandler: trait_class = _TRAIT_CLASSES.get(trait_type) if not trait_class: - raise TraitException("Trait-type '{trait_type} is invalid.") + raise TraitException(f"Trait-type '{trait_type}' is invalid.") trait_properties["name"] = key.title() if not name else name trait_properties["trait_type"] = trait_type @@ -443,10 +444,9 @@ class TraitHandler: # this will raise exception if input is insufficient trait_properties = trait_class.validate_input(trait_properties) - print("trait_properties", trait_properties) - self.trait_data[key] = trait_properties + def remove(self, key): """ Remove a Trait from the handler's parent object. @@ -474,7 +474,10 @@ class TraitHandler: class Trait: - """Represents an object or Character trait. + """Represents an object or Character trait. This simple base is just + storing anything in it's 'value' property, so it's pretty much just a + different wrapper to an Attribute. It does no type-checking of what is + stored. Note: See module docstring for configuration details. @@ -490,7 +493,7 @@ class Trait: # the trait will not be able to be created. # Apart from the keys given here, "name" and "trait_type" will also # always have to be a apart of the data. - data_keys = {} + data_keys = {"value": None} # enable to set/retrieve other arbitrary properties on the Trait # and have them treated like data to store. @@ -527,27 +530,47 @@ class Trait: """ Validate input + Args: + trait_data (dict or _SaverDict): Data to be used for + initialization of this trait. + Returns: + dict: Validated data, possibly complemented with default + values from data_keys. + Raises: + TraitException: If finding unset keys without a default. + """ - req = set(list(cls.data_keys.keys()) + ["name", "trait_type"]) + def _raise_err(unset_required): + """Helper method to format exception.""" + raise TraitException( + "Trait {} could not be created - misses required keys {}.".format( + cls.trait_type, list_to_string(list(unset_required), addquote=True) + ) + ) inp = set(trait_data.keys()) + + # separate check for name/trait_type, those are always required. + req = set(("name", "trait_type")) + unsets = req.difference(inp.intersection(req)) + if unsets: + _raise_err(unsets) + + # check other keys, these likely have defaults to fall back to + req = set(list(cls.data_keys.keys())) unsets = req.difference(inp.intersection(req)) unset_defaults = {key: cls.data_keys[key] for key in unsets} if MandatoryTraitKey in unset_defaults.values(): # we have one or more unset keys that was mandatory - unset_required = [key for key, value in unset_defaults.items() - if value == MandatoryTraitKey] - raise TraitException( - "Trait {} could not be created - misses required keys {}".format( - cls.trait_type, ", ".join(unset_required) - ) - ) + _raise_err([key for key, value in unset_defaults.items() + if value == MandatoryTraitKey]) # apply the default values trait_data.update(unset_defaults) if not cls.allow_extra_properties: # don't allow any extra properties - remove the extra data - for key in inp.difference(req) not in ("name", "trait_type"): + for key in (key for key in inp.difference(req) + if key not in ("name", "trait_type")): del trait_data[key] return trait_data @@ -577,7 +600,7 @@ class Trait: return self._data[key] except KeyError: raise AttributeError( - "{!r} {} ({}) has no attribute {!r}.".format( + "{!r} {} ({}) has no property {!r}.".format( self._data['name'], type(self).__name__, self.trait_type, @@ -612,8 +635,30 @@ class Trait: f"{self.trait_type} Trait.") def __delattr__(self, key): - """Delete extra parameters as attributes.""" - if key not in _GA(self, properties) and key in self._data: + """ + Delete or reset parameters. + + Args: + key (str): property-key to delete. + Raises: + TraitException: If trying to delete a data-key + without a default value to reset to. + Notes: + This will outright delete extra keys (if allow_extra_properties is + set). Keys in self.data_keys with a default value will be + reset to default. A data_key with a default of MandatoryDefaultKey + will raise a TraitException. Unfound matches will be silently ignored. + + """ + if key in self.data_keys: + if self.data_keys[key] == MandatoryTraitKey: + raise TraitException( + "Trait-Key {key} cannot be deleted: It's a mandatory property " + "with no default value to fall back to.") + # set to default + self._data[key] = self.data_keys[key] + elif key in self._data: + # an extra property. Delete as normal. del self._data[key] def __repr__(self): @@ -621,12 +666,12 @@ class Trait: return "{}({{{}}})".format( type(self).__name__, ", ".join( - ["'{}': {!r}".format(k, self._data[k]) for k in self._keys if k in self._data] + ["'{}': {!r}".format(k, self._data[k]) for k in self.data_keys if k in self._data] ), ) def __str__(self): - return f"" + return f"" # access properties @@ -637,13 +682,23 @@ class Trait: key = name + @property + def value(self): + """Store a value""" + return self._data["value"] + + @value.setter + def value(self, value): + """Get value""" + self._data["value"] = value + @total_ordering class NumericTrait(Trait): """ Base trait for all Traits based on numbers. This implements - number-comparisons, limits etc. It also features a "modifier" to the value, - since this is a common use. + number-comparisons, limits etc. It works on the 'base' property since this + makes more sense for child classes. """ @@ -652,6 +707,8 @@ class NumericTrait(Trait): data_keys = { "base": 0 } + def __str__(self): + return f"" # Numeric operations