Add unit tests for base traits

This commit is contained in:
Griatch 2020-04-18 14:43:10 +02:00
parent b442946d0f
commit c41fd0a33b
2 changed files with 407 additions and 48 deletions

View file

@ -6,13 +6,13 @@ Unit test module for Trait classes.
"""
from mock import MagicMock
from copy import copy
from anything import Something
from mock import MagicMock, patch
from django.test import TestCase
from django.test import override_settings
from evennia.utils.test_resources import EvenniaTest
from evennia.contrib.traits import (
TraitHandler, Trait, NumericTrait,
StaticTrait, CounterTrait, GaugeTrait)
from evennia.contrib import traits
class _MockObj:
@ -31,35 +31,337 @@ class _MockObj:
assert category == self.category
self.dbstore[key] = value
# we want to test the base traits too
_TEST_TRAIT_CLASS_PATHS = [
"evennia.contrib.traits.Trait",
"evennia.contrib.traits.NumericTrait",
"evennia.contrib.traits.StaticTrait",
"evennia.contrib.traits.CounterTrait",
"evennia.contrib.traits.GaugeTrait",
]
@override_settings(TRAIT_CLASS_PATHS=["evennia.contrib.traits.Trait"])
class TraitHandlerTest(TestCase):
"""Test case for TraitHandler"""
class _TraitHandlerBase(TestCase):
"Base for trait tests"
@patch("evennia.contrib.traits._TRAIT_CLASS_PATHS", new=_TEST_TRAIT_CLASS_PATHS)
def setUp(self):
self.obj = _MockObj()
self.traithandler = TraitHandler(self.obj)
self.traithandler = traits.TraitHandler(self.obj)
self.obj.traits = self.traithandler
def _get_dbstore(self, key):
return self.obj.dbstore['traits'][key]
class TraitHandlerTest(_TraitHandlerBase):
"""Testing for TraitHandler"""
def setUp(self):
super().setUp()
self.traithandler.add(
"test1",
name="Test1",
trait_type='trait'
)
self.traithandler.add(
"test2",
name="Test2",
trait_type='trait',
value=["foo", {"1": [1, 2, 3]}, 4],
)
def test_add_trait(self):
self.traithandler.add(
"test",
name="Test",
trait_type="trait"
)
self.assertEqual(
self.obj.dbstore["test"],
{"name": "Test",
"trait_type": "trait"}
self._get_dbstore("test1"),
{"name": "Test1",
"trait_type": 'trait',
"value": None,
}
)
self.assertEqual(
self._get_dbstore("test2"),
{"name": "Test2",
"trait_type": 'trait',
"value": ["foo", {"1": [1, 2, 3]}, 4],
}
)
self.assertEqual(len(self.traithandler), 2)
def test_cache(self):
"""
Cache should not be set until first get
"""
self.assertEqual(len(self.traithandler._cache), 0)
self.traithandler.all # does not affect cache
self.assertEqual(len(self.traithandler._cache), 0)
self.traithandler.test1
self.assertEqual(len(self.traithandler._cache), 1)
self.traithandler.test2
self.assertEqual(len(self.traithandler._cache), 2)
def test_setting(self):
"Don't allow setting stuff on traithandler"
with self.assertRaises(traits.TraitException):
self.traithandler.foo = "bar"
with self.assertRaises(traits.TraitException):
self.traithandler["foo"] = "bar"
def test_getting(self):
"Test we are getting data from the dbstore"
self.assertEqual(
self.traithandler.test1._data,
{"name": "Test1", "trait_type": "trait",
"value": None}
)
self.assertEqual(
self.traithandler._cache, Something
)
self.assertEqual(
self.traithandler.test2._data,
{"name": "Test2", "trait_type": "trait",
"value": ["foo", {"1": [1, 2, 3]}, 4]}
)
self.assertEqual(
self.traithandler._cache, Something
)
self.assertFalse(self.traithandler.get("foo"))
self.assertFalse(self.traithandler.bar)
def test_all(self):
"Test all method"
self.assertEqual(self.traithandler.all, ["test1", "test2"])
def test_remove(self):
"Test remove method"
self.traithandler.remove("test2")
self.assertEqual(len(self.traithandler), 1)
self.assertTrue(bool(self.traithandler.get("test1"))) # this populates cache
self.assertEqual(len(self.traithandler._cache), 1)
with self.assertRaises(traits.TraitException):
self.traithandler.remove("foo")
def test_clear(self):
"Test clear method"
self.traithandler.clear()
self.assertEqual(len(self.traithandler), 0)
def test_trait_db_connection(self):
"Test that updating a trait property actually updates value in db"
trait = self.traithandler.test1
self.assertEqual(trait.value, None)
trait.value = 10
self.assertEqual(trait.value, 10)
self.assertEqual(
self.obj.attributes.get("traits", category="traits")['test1']['value'],
10
)
trait.value = 20
self.assertEqual(trait.value, 20)
self.assertEqual(
self.obj.attributes.get("traits", category="traits")['test1']['value'],
20
)
del trait.value
self.assertEqual(
self.obj.attributes.get("traits", category="traits")['test1']['value'],
None
)
class TraitTest(_TraitHandlerBase):
"""
Test the base Trait class
"""
def setUp(self):
super().setUp()
self.traithandler.add(
"test1",
name="Test1",
trait_type="trait",
value="value",
extra_val1="xvalue1",
extra_val2="xvalue2",
)
self.trait = self.traithandler.get("test1")
def test_init(self):
self.assertEqual(
self.trait._data,
{"name": "Test1",
"trait_type": "trait",
"value": "value",
"extra_val1": "xvalue1",
"extra_val2": "xvalue2"
}
)
def test_validate_input__valid(self):
"""Test valid validation input"""
# all data supplied, and extras
dat = {
"name": "Test",
"trait_type": "trait",
"value": 10,
"extra_val": 1000
}
expected = copy(dat) # we must break link or return === dat always
self.assertEqual(expected, traits.Trait.validate_input(dat))
# don't supply value, should get default
dat = {
"name": "Test",
"trait_type": "trait",
# missing value
"extra_val": 1000
}
expected = copy(dat)
expected["value"] = traits.Trait.data_keys['value']
self.assertEqual(expected, traits.Trait.validate_input(dat))
# make sure extra values are cleaned if trait accepts no extras
dat = {
"name": "Test",
"trait_type": "trait",
"value": 10,
"extra_val1": 1000,
"extra_val2": "xvalue"
}
expected = copy(dat)
expected.pop("extra_val1")
expected.pop("extra_val2")
with patch.object(traits.Trait, "allow_extra_properties", False):
self.assertEqual(expected, traits.Trait.validate_input(dat))
def test_validate_input__fail(self):
"""Test failing validation"""
dat = {
# missing name
"trait_type": "trait",
"value": 10,
"extra_val": 1000
}
with self.assertRaises(traits.TraitException):
traits.Trait.validate_input(dat)
# make value a required key
mock_data_keys = {
"value": traits.MandatoryTraitKey
}
with patch.object(traits.Trait, "data_keys", mock_data_keys):
dat = {
"name": "Trait",
"trait_type": "trait",
# missing value, now mandatory
"extra_val": 1000
}
with self.assertRaises(traits.TraitException):
traits.Trait.validate_input(dat)
def test_trait_getset(self):
"""Get-set-del operations on trait"""
self.assertEqual(self.trait.name, "Test1")
self.assertEqual(self.trait['name'], "Test1")
self.assertEqual(self.trait.value, "value")
self.assertEqual(self.trait['value'], "value")
self.assertEqual(self.trait.extra_val1, "xvalue1" )
self.assertEqual(self.trait['extra_val2'], "xvalue2")
self.trait.value = 20
self.assertEqual(self.trait['value'], 20)
self.trait['value'] = 20
self.assertEqual(self.trait.value, 20)
self.trait.extra_val1 = 100
self.assertEqual(self.trait.extra_val1, 100)
# additional properties
self.trait.foo = "bar"
self.assertEqual(self.trait.foo, "bar")
del self.trait.foo
with self.assertRaises(KeyError):
self.trait['foo']
with self.assertRaises(AttributeError):
self.trait.foo
del self.trait.extra_val1
with self.assertRaises(AttributeError):
self.trait.extra_val1
del self.trait.value
# fall back to default
self.assertTrue(self.trait.value == traits.Trait.data_keys["value"])
def test_repr(self):
self.assertEqual(repr(self.trait), Something)
self.assertEqual(str(self.trait), Something)
class TestTraitNumeric(_TraitHandlerBase):
def test_trait__numeric(self):
self.traithandler.add(
"test2",
name="Test2",
trait_type='numeric',
)
self.assertEqual(
self._get_dbstore("test2"),
{"name": "Test2",
"trait_type": 'numeric',
"base": 0,
}
)
def test_trait__static(self):
self.traithandler.add(
"test3",
name="Test3",
trait_type='static'
)
self.assertEqual(
self._get_dbstore("test3"),
{"name": "Test3",
"trait_type": 'static',
"base": 0,
"mod": 0,
}
)
def test_trait__counter(self):
self.traithandler.add(
"test4",
name="Test4",
trait_type='counter'
)
self.assertEqual(
self._get_dbstore("test4"),
{"name": "Test4",
"trait_type": 'counter',
"base": 0,
"mod": 0,
"current": 0,
"max_value": None,
"min_value": None,
}
)
def test_trait__gauge(self):
self.traithandler.add(
"test5",
name="Test5",
trait_type='gauge'
)
self.assertEqual(
self._get_dbstore("test5"),
{"name": "Test5",
"trait_type": 'gauge',
"base": 0,
"mod": 0,
"current": 0,
"max_value": None,
"min_value": None,
}
)
#
#

View file

@ -241,7 +241,7 @@ from django.conf import settings
from functools import total_ordering
from evennia.utils.dbserialize import _SaverDict
from evennia.utils import logger
from evennia.utils.utils import inherits_from, class_from_module
from evennia.utils.utils import inherits_from, class_from_module, list_to_string
# This way the user can easily supply their own. Each
@ -292,7 +292,7 @@ _SA = object.__setattr__
DEFAULT_TRAIT_TYPE = "static"
class TraitException(Exception):
class TraitException(RuntimeError):
"""
Base exception class raised by `Trait` objects.
@ -331,13 +331,15 @@ class TraitHandler:
# load the available classes, if necessary
_delayed_import_trait_classes()
# Note that this retains the connection to the database, meaning every
# Note that .trait_data retains the connection to the database, meaning every
# update we do to .trait_data automatically syncs with database.
self.trait_data = obj.attributes.get(db_attribute_key, category=db_attribute_category)
if self.trait_data is None:
# no existing storage; initialize it
# no existing storage; initialize it, we then have to fetch it again
# to retain the db connection
obj.attributes.add(db_attribute_key, {}, category=db_attribute_category)
self.trait_data = {}
self.trait_data = obj.attributes.get(
db_attribute_key, category=db_attribute_category)
self._cache = {}
def __len__(self):
@ -385,8 +387,7 @@ class TraitHandler:
def get(self, key):
"""
Args:
trait (str): key from the traits dict containing config data
for the trait. "all" returns a list of all trait keys.
key (str): key from the traits dict containing config data.
Returns:
(`Trait` or `None`): named Trait class or None if trait key
@ -435,7 +436,7 @@ class TraitHandler:
trait_class = _TRAIT_CLASSES.get(trait_type)
if not trait_class:
raise TraitException("Trait-type '{trait_type} is invalid.")
raise TraitException(f"Trait-type '{trait_type}' is invalid.")
trait_properties["name"] = key.title() if not name else name
trait_properties["trait_type"] = trait_type
@ -443,10 +444,9 @@ class TraitHandler:
# this will raise exception if input is insufficient
trait_properties = trait_class.validate_input(trait_properties)
print("trait_properties", trait_properties)
self.trait_data[key] = trait_properties
def remove(self, key):
"""
Remove a Trait from the handler's parent object.
@ -474,7 +474,10 @@ class TraitHandler:
class Trait:
"""Represents an object or Character trait.
"""Represents an object or Character trait. This simple base is just
storing anything in it's 'value' property, so it's pretty much just a
different wrapper to an Attribute. It does no type-checking of what is
stored.
Note:
See module docstring for configuration details.
@ -490,7 +493,7 @@ class Trait:
# the trait will not be able to be created.
# Apart from the keys given here, "name" and "trait_type" will also
# always have to be a apart of the data.
data_keys = {}
data_keys = {"value": None}
# enable to set/retrieve other arbitrary properties on the Trait
# and have them treated like data to store.
@ -527,27 +530,47 @@ class Trait:
"""
Validate input
Args:
trait_data (dict or _SaverDict): Data to be used for
initialization of this trait.
Returns:
dict: Validated data, possibly complemented with default
values from data_keys.
Raises:
TraitException: If finding unset keys without a default.
"""
req = set(list(cls.data_keys.keys()) + ["name", "trait_type"])
def _raise_err(unset_required):
"""Helper method to format exception."""
raise TraitException(
"Trait {} could not be created - misses required keys {}.".format(
cls.trait_type, list_to_string(list(unset_required), addquote=True)
)
)
inp = set(trait_data.keys())
# separate check for name/trait_type, those are always required.
req = set(("name", "trait_type"))
unsets = req.difference(inp.intersection(req))
if unsets:
_raise_err(unsets)
# check other keys, these likely have defaults to fall back to
req = set(list(cls.data_keys.keys()))
unsets = req.difference(inp.intersection(req))
unset_defaults = {key: cls.data_keys[key] for key in unsets}
if MandatoryTraitKey in unset_defaults.values():
# we have one or more unset keys that was mandatory
unset_required = [key for key, value in unset_defaults.items()
if value == MandatoryTraitKey]
raise TraitException(
"Trait {} could not be created - misses required keys {}".format(
cls.trait_type, ", ".join(unset_required)
)
)
_raise_err([key for key, value in unset_defaults.items()
if value == MandatoryTraitKey])
# apply the default values
trait_data.update(unset_defaults)
if not cls.allow_extra_properties:
# don't allow any extra properties - remove the extra data
for key in inp.difference(req) not in ("name", "trait_type"):
for key in (key for key in inp.difference(req)
if key not in ("name", "trait_type")):
del trait_data[key]
return trait_data
@ -577,7 +600,7 @@ class Trait:
return self._data[key]
except KeyError:
raise AttributeError(
"{!r} {} ({}) has no attribute {!r}.".format(
"{!r} {} ({}) has no property {!r}.".format(
self._data['name'],
type(self).__name__,
self.trait_type,
@ -612,8 +635,30 @@ class Trait:
f"{self.trait_type} Trait.")
def __delattr__(self, key):
"""Delete extra parameters as attributes."""
if key not in _GA(self, properties) and key in self._data:
"""
Delete or reset parameters.
Args:
key (str): property-key to delete.
Raises:
TraitException: If trying to delete a data-key
without a default value to reset to.
Notes:
This will outright delete extra keys (if allow_extra_properties is
set). Keys in self.data_keys with a default value will be
reset to default. A data_key with a default of MandatoryDefaultKey
will raise a TraitException. Unfound matches will be silently ignored.
"""
if key in self.data_keys:
if self.data_keys[key] == MandatoryTraitKey:
raise TraitException(
"Trait-Key {key} cannot be deleted: It's a mandatory property "
"with no default value to fall back to.")
# set to default
self._data[key] = self.data_keys[key]
elif key in self._data:
# an extra property. Delete as normal.
del self._data[key]
def __repr__(self):
@ -621,12 +666,12 @@ class Trait:
return "{}({{{}}})".format(
type(self).__name__,
", ".join(
["'{}': {!r}".format(k, self._data[k]) for k in self._keys if k in self._data]
["'{}': {!r}".format(k, self._data[k]) for k in self.data_keys if k in self._data]
),
)
def __str__(self):
return f"<Trait {self.name}>"
return f"<Trait {self.name}: {self._data['value']}>"
# access properties
@ -637,13 +682,23 @@ class Trait:
key = name
@property
def value(self):
"""Store a value"""
return self._data["value"]
@value.setter
def value(self, value):
"""Get value"""
self._data["value"] = value
@total_ordering
class NumericTrait(Trait):
"""
Base trait for all Traits based on numbers. This implements
number-comparisons, limits etc. It also features a "modifier" to the value,
since this is a common use.
number-comparisons, limits etc. It works on the 'base' property since this
makes more sense for child classes.
"""
@ -652,6 +707,8 @@ class NumericTrait(Trait):
data_keys = {
"base": 0
}
def __str__(self):
return f"<Trait {self.name}: {self._data['base']}>"
# Numeric operations