mirror of
https://github.com/evennia/evennia.git
synced 2026-03-30 20:47:17 +02:00
Add unit tests for base traits
This commit is contained in:
parent
b442946d0f
commit
c41fd0a33b
2 changed files with 407 additions and 48 deletions
|
|
@ -6,13 +6,13 @@ Unit test module for Trait classes.
|
|||
|
||||
"""
|
||||
|
||||
from mock import MagicMock
|
||||
from copy import copy
|
||||
from anything import Something
|
||||
from mock import MagicMock, patch
|
||||
from django.test import TestCase
|
||||
from django.test import override_settings
|
||||
from evennia.utils.test_resources import EvenniaTest
|
||||
from evennia.contrib.traits import (
|
||||
TraitHandler, Trait, NumericTrait,
|
||||
StaticTrait, CounterTrait, GaugeTrait)
|
||||
from evennia.contrib import traits
|
||||
|
||||
|
||||
class _MockObj:
|
||||
|
|
@ -31,35 +31,337 @@ class _MockObj:
|
|||
assert category == self.category
|
||||
self.dbstore[key] = value
|
||||
|
||||
# we want to test the base traits too
|
||||
_TEST_TRAIT_CLASS_PATHS = [
|
||||
"evennia.contrib.traits.Trait",
|
||||
"evennia.contrib.traits.NumericTrait",
|
||||
"evennia.contrib.traits.StaticTrait",
|
||||
"evennia.contrib.traits.CounterTrait",
|
||||
"evennia.contrib.traits.GaugeTrait",
|
||||
]
|
||||
|
||||
@override_settings(TRAIT_CLASS_PATHS=["evennia.contrib.traits.Trait"])
|
||||
class TraitHandlerTest(TestCase):
|
||||
"""Test case for TraitHandler"""
|
||||
|
||||
class _TraitHandlerBase(TestCase):
|
||||
"Base for trait tests"
|
||||
@patch("evennia.contrib.traits._TRAIT_CLASS_PATHS", new=_TEST_TRAIT_CLASS_PATHS)
|
||||
def setUp(self):
|
||||
self.obj = _MockObj()
|
||||
self.traithandler = TraitHandler(self.obj)
|
||||
self.traithandler = traits.TraitHandler(self.obj)
|
||||
self.obj.traits = self.traithandler
|
||||
|
||||
def _get_dbstore(self, key):
|
||||
return self.obj.dbstore['traits'][key]
|
||||
|
||||
|
||||
class TraitHandlerTest(_TraitHandlerBase):
|
||||
"""Testing for TraitHandler"""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.traithandler.add(
|
||||
"test1",
|
||||
name="Test1",
|
||||
trait_type='trait'
|
||||
)
|
||||
self.traithandler.add(
|
||||
"test2",
|
||||
name="Test2",
|
||||
trait_type='trait',
|
||||
value=["foo", {"1": [1, 2, 3]}, 4],
|
||||
)
|
||||
|
||||
def test_add_trait(self):
|
||||
|
||||
self.traithandler.add(
|
||||
"test",
|
||||
name="Test",
|
||||
trait_type="trait"
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
self.obj.dbstore["test"],
|
||||
{"name": "Test",
|
||||
"trait_type": "trait"}
|
||||
self._get_dbstore("test1"),
|
||||
{"name": "Test1",
|
||||
"trait_type": 'trait',
|
||||
"value": None,
|
||||
}
|
||||
)
|
||||
self.assertEqual(
|
||||
self._get_dbstore("test2"),
|
||||
{"name": "Test2",
|
||||
"trait_type": 'trait',
|
||||
"value": ["foo", {"1": [1, 2, 3]}, 4],
|
||||
}
|
||||
)
|
||||
self.assertEqual(len(self.traithandler), 2)
|
||||
|
||||
def test_cache(self):
|
||||
"""
|
||||
Cache should not be set until first get
|
||||
"""
|
||||
self.assertEqual(len(self.traithandler._cache), 0)
|
||||
self.traithandler.all # does not affect cache
|
||||
self.assertEqual(len(self.traithandler._cache), 0)
|
||||
self.traithandler.test1
|
||||
self.assertEqual(len(self.traithandler._cache), 1)
|
||||
self.traithandler.test2
|
||||
self.assertEqual(len(self.traithandler._cache), 2)
|
||||
|
||||
def test_setting(self):
|
||||
"Don't allow setting stuff on traithandler"
|
||||
with self.assertRaises(traits.TraitException):
|
||||
self.traithandler.foo = "bar"
|
||||
with self.assertRaises(traits.TraitException):
|
||||
self.traithandler["foo"] = "bar"
|
||||
|
||||
def test_getting(self):
|
||||
"Test we are getting data from the dbstore"
|
||||
self.assertEqual(
|
||||
self.traithandler.test1._data,
|
||||
{"name": "Test1", "trait_type": "trait",
|
||||
"value": None}
|
||||
)
|
||||
self.assertEqual(
|
||||
self.traithandler._cache, Something
|
||||
)
|
||||
self.assertEqual(
|
||||
self.traithandler.test2._data,
|
||||
{"name": "Test2", "trait_type": "trait",
|
||||
"value": ["foo", {"1": [1, 2, 3]}, 4]}
|
||||
)
|
||||
self.assertEqual(
|
||||
self.traithandler._cache, Something
|
||||
)
|
||||
self.assertFalse(self.traithandler.get("foo"))
|
||||
self.assertFalse(self.traithandler.bar)
|
||||
|
||||
def test_all(self):
|
||||
"Test all method"
|
||||
self.assertEqual(self.traithandler.all, ["test1", "test2"])
|
||||
|
||||
def test_remove(self):
|
||||
"Test remove method"
|
||||
self.traithandler.remove("test2")
|
||||
self.assertEqual(len(self.traithandler), 1)
|
||||
self.assertTrue(bool(self.traithandler.get("test1"))) # this populates cache
|
||||
self.assertEqual(len(self.traithandler._cache), 1)
|
||||
with self.assertRaises(traits.TraitException):
|
||||
self.traithandler.remove("foo")
|
||||
|
||||
def test_clear(self):
|
||||
"Test clear method"
|
||||
self.traithandler.clear()
|
||||
self.assertEqual(len(self.traithandler), 0)
|
||||
|
||||
def test_trait_db_connection(self):
|
||||
"Test that updating a trait property actually updates value in db"
|
||||
trait = self.traithandler.test1
|
||||
self.assertEqual(trait.value, None)
|
||||
trait.value = 10
|
||||
self.assertEqual(trait.value, 10)
|
||||
self.assertEqual(
|
||||
self.obj.attributes.get("traits", category="traits")['test1']['value'],
|
||||
10
|
||||
)
|
||||
trait.value = 20
|
||||
self.assertEqual(trait.value, 20)
|
||||
self.assertEqual(
|
||||
self.obj.attributes.get("traits", category="traits")['test1']['value'],
|
||||
20
|
||||
)
|
||||
del trait.value
|
||||
self.assertEqual(
|
||||
self.obj.attributes.get("traits", category="traits")['test1']['value'],
|
||||
None
|
||||
)
|
||||
|
||||
|
||||
class TraitTest(_TraitHandlerBase):
|
||||
"""
|
||||
Test the base Trait class
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.traithandler.add(
|
||||
"test1",
|
||||
name="Test1",
|
||||
trait_type="trait",
|
||||
value="value",
|
||||
extra_val1="xvalue1",
|
||||
extra_val2="xvalue2",
|
||||
)
|
||||
self.trait = self.traithandler.get("test1")
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(
|
||||
self.trait._data,
|
||||
{"name": "Test1",
|
||||
"trait_type": "trait",
|
||||
"value": "value",
|
||||
"extra_val1": "xvalue1",
|
||||
"extra_val2": "xvalue2"
|
||||
}
|
||||
)
|
||||
|
||||
def test_validate_input__valid(self):
|
||||
"""Test valid validation input"""
|
||||
# all data supplied, and extras
|
||||
dat = {
|
||||
"name": "Test",
|
||||
"trait_type": "trait",
|
||||
"value": 10,
|
||||
"extra_val": 1000
|
||||
}
|
||||
expected = copy(dat) # we must break link or return === dat always
|
||||
self.assertEqual(expected, traits.Trait.validate_input(dat))
|
||||
|
||||
# don't supply value, should get default
|
||||
dat = {
|
||||
"name": "Test",
|
||||
"trait_type": "trait",
|
||||
# missing value
|
||||
"extra_val": 1000
|
||||
}
|
||||
expected = copy(dat)
|
||||
expected["value"] = traits.Trait.data_keys['value']
|
||||
self.assertEqual(expected, traits.Trait.validate_input(dat))
|
||||
|
||||
# make sure extra values are cleaned if trait accepts no extras
|
||||
dat = {
|
||||
"name": "Test",
|
||||
"trait_type": "trait",
|
||||
"value": 10,
|
||||
"extra_val1": 1000,
|
||||
"extra_val2": "xvalue"
|
||||
}
|
||||
expected = copy(dat)
|
||||
expected.pop("extra_val1")
|
||||
expected.pop("extra_val2")
|
||||
with patch.object(traits.Trait, "allow_extra_properties", False):
|
||||
self.assertEqual(expected, traits.Trait.validate_input(dat))
|
||||
|
||||
def test_validate_input__fail(self):
|
||||
"""Test failing validation"""
|
||||
dat = {
|
||||
# missing name
|
||||
"trait_type": "trait",
|
||||
"value": 10,
|
||||
"extra_val": 1000
|
||||
}
|
||||
with self.assertRaises(traits.TraitException):
|
||||
traits.Trait.validate_input(dat)
|
||||
|
||||
# make value a required key
|
||||
mock_data_keys = {
|
||||
"value": traits.MandatoryTraitKey
|
||||
}
|
||||
with patch.object(traits.Trait, "data_keys", mock_data_keys):
|
||||
dat = {
|
||||
"name": "Trait",
|
||||
"trait_type": "trait",
|
||||
# missing value, now mandatory
|
||||
"extra_val": 1000
|
||||
}
|
||||
with self.assertRaises(traits.TraitException):
|
||||
traits.Trait.validate_input(dat)
|
||||
|
||||
def test_trait_getset(self):
|
||||
"""Get-set-del operations on trait"""
|
||||
self.assertEqual(self.trait.name, "Test1")
|
||||
self.assertEqual(self.trait['name'], "Test1")
|
||||
self.assertEqual(self.trait.value, "value")
|
||||
self.assertEqual(self.trait['value'], "value")
|
||||
self.assertEqual(self.trait.extra_val1, "xvalue1" )
|
||||
self.assertEqual(self.trait['extra_val2'], "xvalue2")
|
||||
|
||||
self.trait.value = 20
|
||||
self.assertEqual(self.trait['value'], 20)
|
||||
self.trait['value'] = 20
|
||||
self.assertEqual(self.trait.value, 20)
|
||||
self.trait.extra_val1 = 100
|
||||
self.assertEqual(self.trait.extra_val1, 100)
|
||||
# additional properties
|
||||
self.trait.foo = "bar"
|
||||
self.assertEqual(self.trait.foo, "bar")
|
||||
|
||||
del self.trait.foo
|
||||
with self.assertRaises(KeyError):
|
||||
self.trait['foo']
|
||||
with self.assertRaises(AttributeError):
|
||||
self.trait.foo
|
||||
del self.trait.extra_val1
|
||||
with self.assertRaises(AttributeError):
|
||||
self.trait.extra_val1
|
||||
del self.trait.value
|
||||
# fall back to default
|
||||
self.assertTrue(self.trait.value == traits.Trait.data_keys["value"])
|
||||
|
||||
def test_repr(self):
|
||||
self.assertEqual(repr(self.trait), Something)
|
||||
self.assertEqual(str(self.trait), Something)
|
||||
|
||||
|
||||
class TestTraitNumeric(_TraitHandlerBase):
|
||||
|
||||
def test_trait__numeric(self):
|
||||
self.traithandler.add(
|
||||
"test2",
|
||||
name="Test2",
|
||||
trait_type='numeric',
|
||||
)
|
||||
self.assertEqual(
|
||||
self._get_dbstore("test2"),
|
||||
{"name": "Test2",
|
||||
"trait_type": 'numeric',
|
||||
"base": 0,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
def test_trait__static(self):
|
||||
self.traithandler.add(
|
||||
"test3",
|
||||
name="Test3",
|
||||
trait_type='static'
|
||||
)
|
||||
self.assertEqual(
|
||||
self._get_dbstore("test3"),
|
||||
{"name": "Test3",
|
||||
"trait_type": 'static',
|
||||
"base": 0,
|
||||
"mod": 0,
|
||||
}
|
||||
)
|
||||
|
||||
def test_trait__counter(self):
|
||||
self.traithandler.add(
|
||||
"test4",
|
||||
name="Test4",
|
||||
trait_type='counter'
|
||||
)
|
||||
self.assertEqual(
|
||||
self._get_dbstore("test4"),
|
||||
{"name": "Test4",
|
||||
"trait_type": 'counter',
|
||||
"base": 0,
|
||||
"mod": 0,
|
||||
"current": 0,
|
||||
"max_value": None,
|
||||
"min_value": None,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_trait__gauge(self):
|
||||
self.traithandler.add(
|
||||
"test5",
|
||||
name="Test5",
|
||||
trait_type='gauge'
|
||||
)
|
||||
self.assertEqual(
|
||||
self._get_dbstore("test5"),
|
||||
{"name": "Test5",
|
||||
"trait_type": 'gauge',
|
||||
"base": 0,
|
||||
"mod": 0,
|
||||
"current": 0,
|
||||
"max_value": None,
|
||||
"min_value": None,
|
||||
}
|
||||
)
|
||||
|
||||
#
|
||||
#
|
||||
|
|
|
|||
|
|
@ -241,7 +241,7 @@ from django.conf import settings
|
|||
from functools import total_ordering
|
||||
from evennia.utils.dbserialize import _SaverDict
|
||||
from evennia.utils import logger
|
||||
from evennia.utils.utils import inherits_from, class_from_module
|
||||
from evennia.utils.utils import inherits_from, class_from_module, list_to_string
|
||||
|
||||
|
||||
# This way the user can easily supply their own. Each
|
||||
|
|
@ -292,7 +292,7 @@ _SA = object.__setattr__
|
|||
DEFAULT_TRAIT_TYPE = "static"
|
||||
|
||||
|
||||
class TraitException(Exception):
|
||||
class TraitException(RuntimeError):
|
||||
"""
|
||||
Base exception class raised by `Trait` objects.
|
||||
|
||||
|
|
@ -331,13 +331,15 @@ class TraitHandler:
|
|||
# load the available classes, if necessary
|
||||
_delayed_import_trait_classes()
|
||||
|
||||
# Note that this retains the connection to the database, meaning every
|
||||
# Note that .trait_data retains the connection to the database, meaning every
|
||||
# update we do to .trait_data automatically syncs with database.
|
||||
self.trait_data = obj.attributes.get(db_attribute_key, category=db_attribute_category)
|
||||
if self.trait_data is None:
|
||||
# no existing storage; initialize it
|
||||
# no existing storage; initialize it, we then have to fetch it again
|
||||
# to retain the db connection
|
||||
obj.attributes.add(db_attribute_key, {}, category=db_attribute_category)
|
||||
self.trait_data = {}
|
||||
self.trait_data = obj.attributes.get(
|
||||
db_attribute_key, category=db_attribute_category)
|
||||
self._cache = {}
|
||||
|
||||
def __len__(self):
|
||||
|
|
@ -385,8 +387,7 @@ class TraitHandler:
|
|||
def get(self, key):
|
||||
"""
|
||||
Args:
|
||||
trait (str): key from the traits dict containing config data
|
||||
for the trait. "all" returns a list of all trait keys.
|
||||
key (str): key from the traits dict containing config data.
|
||||
|
||||
Returns:
|
||||
(`Trait` or `None`): named Trait class or None if trait key
|
||||
|
|
@ -435,7 +436,7 @@ class TraitHandler:
|
|||
|
||||
trait_class = _TRAIT_CLASSES.get(trait_type)
|
||||
if not trait_class:
|
||||
raise TraitException("Trait-type '{trait_type} is invalid.")
|
||||
raise TraitException(f"Trait-type '{trait_type}' is invalid.")
|
||||
|
||||
trait_properties["name"] = key.title() if not name else name
|
||||
trait_properties["trait_type"] = trait_type
|
||||
|
|
@ -443,10 +444,9 @@ class TraitHandler:
|
|||
# this will raise exception if input is insufficient
|
||||
trait_properties = trait_class.validate_input(trait_properties)
|
||||
|
||||
print("trait_properties", trait_properties)
|
||||
|
||||
self.trait_data[key] = trait_properties
|
||||
|
||||
|
||||
def remove(self, key):
|
||||
"""
|
||||
Remove a Trait from the handler's parent object.
|
||||
|
|
@ -474,7 +474,10 @@ class TraitHandler:
|
|||
|
||||
|
||||
class Trait:
|
||||
"""Represents an object or Character trait.
|
||||
"""Represents an object or Character trait. This simple base is just
|
||||
storing anything in it's 'value' property, so it's pretty much just a
|
||||
different wrapper to an Attribute. It does no type-checking of what is
|
||||
stored.
|
||||
|
||||
Note:
|
||||
See module docstring for configuration details.
|
||||
|
|
@ -490,7 +493,7 @@ class Trait:
|
|||
# the trait will not be able to be created.
|
||||
# Apart from the keys given here, "name" and "trait_type" will also
|
||||
# always have to be a apart of the data.
|
||||
data_keys = {}
|
||||
data_keys = {"value": None}
|
||||
|
||||
# enable to set/retrieve other arbitrary properties on the Trait
|
||||
# and have them treated like data to store.
|
||||
|
|
@ -527,27 +530,47 @@ class Trait:
|
|||
"""
|
||||
Validate input
|
||||
|
||||
Args:
|
||||
trait_data (dict or _SaverDict): Data to be used for
|
||||
initialization of this trait.
|
||||
Returns:
|
||||
dict: Validated data, possibly complemented with default
|
||||
values from data_keys.
|
||||
Raises:
|
||||
TraitException: If finding unset keys without a default.
|
||||
|
||||
"""
|
||||
req = set(list(cls.data_keys.keys()) + ["name", "trait_type"])
|
||||
def _raise_err(unset_required):
|
||||
"""Helper method to format exception."""
|
||||
raise TraitException(
|
||||
"Trait {} could not be created - misses required keys {}.".format(
|
||||
cls.trait_type, list_to_string(list(unset_required), addquote=True)
|
||||
)
|
||||
)
|
||||
inp = set(trait_data.keys())
|
||||
|
||||
# separate check for name/trait_type, those are always required.
|
||||
req = set(("name", "trait_type"))
|
||||
unsets = req.difference(inp.intersection(req))
|
||||
if unsets:
|
||||
_raise_err(unsets)
|
||||
|
||||
# check other keys, these likely have defaults to fall back to
|
||||
req = set(list(cls.data_keys.keys()))
|
||||
unsets = req.difference(inp.intersection(req))
|
||||
unset_defaults = {key: cls.data_keys[key] for key in unsets}
|
||||
|
||||
if MandatoryTraitKey in unset_defaults.values():
|
||||
# we have one or more unset keys that was mandatory
|
||||
unset_required = [key for key, value in unset_defaults.items()
|
||||
if value == MandatoryTraitKey]
|
||||
raise TraitException(
|
||||
"Trait {} could not be created - misses required keys {}".format(
|
||||
cls.trait_type, ", ".join(unset_required)
|
||||
)
|
||||
)
|
||||
_raise_err([key for key, value in unset_defaults.items()
|
||||
if value == MandatoryTraitKey])
|
||||
# apply the default values
|
||||
trait_data.update(unset_defaults)
|
||||
|
||||
if not cls.allow_extra_properties:
|
||||
# don't allow any extra properties - remove the extra data
|
||||
for key in inp.difference(req) not in ("name", "trait_type"):
|
||||
for key in (key for key in inp.difference(req)
|
||||
if key not in ("name", "trait_type")):
|
||||
del trait_data[key]
|
||||
|
||||
return trait_data
|
||||
|
|
@ -577,7 +600,7 @@ class Trait:
|
|||
return self._data[key]
|
||||
except KeyError:
|
||||
raise AttributeError(
|
||||
"{!r} {} ({}) has no attribute {!r}.".format(
|
||||
"{!r} {} ({}) has no property {!r}.".format(
|
||||
self._data['name'],
|
||||
type(self).__name__,
|
||||
self.trait_type,
|
||||
|
|
@ -612,8 +635,30 @@ class Trait:
|
|||
f"{self.trait_type} Trait.")
|
||||
|
||||
def __delattr__(self, key):
|
||||
"""Delete extra parameters as attributes."""
|
||||
if key not in _GA(self, properties) and key in self._data:
|
||||
"""
|
||||
Delete or reset parameters.
|
||||
|
||||
Args:
|
||||
key (str): property-key to delete.
|
||||
Raises:
|
||||
TraitException: If trying to delete a data-key
|
||||
without a default value to reset to.
|
||||
Notes:
|
||||
This will outright delete extra keys (if allow_extra_properties is
|
||||
set). Keys in self.data_keys with a default value will be
|
||||
reset to default. A data_key with a default of MandatoryDefaultKey
|
||||
will raise a TraitException. Unfound matches will be silently ignored.
|
||||
|
||||
"""
|
||||
if key in self.data_keys:
|
||||
if self.data_keys[key] == MandatoryTraitKey:
|
||||
raise TraitException(
|
||||
"Trait-Key {key} cannot be deleted: It's a mandatory property "
|
||||
"with no default value to fall back to.")
|
||||
# set to default
|
||||
self._data[key] = self.data_keys[key]
|
||||
elif key in self._data:
|
||||
# an extra property. Delete as normal.
|
||||
del self._data[key]
|
||||
|
||||
def __repr__(self):
|
||||
|
|
@ -621,12 +666,12 @@ class Trait:
|
|||
return "{}({{{}}})".format(
|
||||
type(self).__name__,
|
||||
", ".join(
|
||||
["'{}': {!r}".format(k, self._data[k]) for k in self._keys if k in self._data]
|
||||
["'{}': {!r}".format(k, self._data[k]) for k in self.data_keys if k in self._data]
|
||||
),
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return f"<Trait {self.name}>"
|
||||
return f"<Trait {self.name}: {self._data['value']}>"
|
||||
|
||||
# access properties
|
||||
|
||||
|
|
@ -637,13 +682,23 @@ class Trait:
|
|||
|
||||
key = name
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
"""Store a value"""
|
||||
return self._data["value"]
|
||||
|
||||
@value.setter
|
||||
def value(self, value):
|
||||
"""Get value"""
|
||||
self._data["value"] = value
|
||||
|
||||
|
||||
@total_ordering
|
||||
class NumericTrait(Trait):
|
||||
"""
|
||||
Base trait for all Traits based on numbers. This implements
|
||||
number-comparisons, limits etc. It also features a "modifier" to the value,
|
||||
since this is a common use.
|
||||
number-comparisons, limits etc. It works on the 'base' property since this
|
||||
makes more sense for child classes.
|
||||
|
||||
"""
|
||||
|
||||
|
|
@ -652,6 +707,8 @@ class NumericTrait(Trait):
|
|||
data_keys = {
|
||||
"base": 0
|
||||
}
|
||||
def __str__(self):
|
||||
return f"<Trait {self.name}: {self._data['base']}>"
|
||||
|
||||
# Numeric operations
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue