diff --git a/evennia/scripts/tests.py b/evennia/scripts/tests.py index b54f465044..bd1d245a25 100644 --- a/evennia/scripts/tests.py +++ b/evennia/scripts/tests.py @@ -3,21 +3,20 @@ Unit tests for the scripts package """ -from unittest import TestCase, mock from collections import defaultdict - -from parameterized import parameterized +from unittest import TestCase, mock from evennia import DefaultScript from evennia.objects.objects import DefaultObject -from evennia.scripts.models import ObjectDoesNotExist, ScriptDB -from evennia.scripts.scripts import DoNothing, ExtendedLoopingCall -from evennia.utils.create import create_script -from evennia.utils.test_resources import BaseEvenniaTest -from evennia.scripts.tickerhandler import TickerHandler -from evennia.scripts.monitorhandler import MonitorHandler from evennia.scripts.manager import ScriptDBManager +from evennia.scripts.models import ObjectDoesNotExist, ScriptDB +from evennia.scripts.monitorhandler import MonitorHandler +from evennia.scripts.scripts import DoNothing, ExtendedLoopingCall +from evennia.scripts.tickerhandler import TickerHandler +from evennia.utils.create import create_script from evennia.utils.dbserialize import dbserialize +from evennia.utils.test_resources import BaseEvenniaTest +from parameterized import parameterized class TestScript(BaseEvenniaTest): @@ -29,34 +28,38 @@ class TestScript(BaseEvenniaTest): self.assertFalse(errors, errors) mockinit.assert_called() + class TestTickerHandler(TestCase): - """ Test the TickerHandler class """ + """Test the TickerHandler class""" def test_store_key_raises_RunTimeError(self): - """ Test _store_key method raises RuntimeError for interval < 1 """ + """Test _store_key method raises RuntimeError for interval < 1""" with self.assertRaises(RuntimeError): - th=TickerHandler() + th = TickerHandler() th._store_key(None, None, 0, None) def test_remove_raises_RunTimeError(self): - """ Test remove method raises RuntimeError for catching old ordering of arguments """ - with self.assertRaises(RuntimeError): - th=TickerHandler() + """Test remove method raises RuntimeError for catching old ordering of arguments""" + with self.assertRaises(RuntimeError): + th = TickerHandler() th.remove(callback=1) + class TestScriptDBManager(TestCase): - """ Test the ScriptDBManger class """ + """Test the ScriptDBManger class""" def test_not_obj_return_empty_list(self): - """ Test get_all_scripts_on_obj returns empty list for falsy object """ + """Test get_all_scripts_on_obj returns empty list for falsy object""" manager_obj = ScriptDBManager() returned_list = manager_obj.get_all_scripts_on_obj(False) self.assertEqual(returned_list, []) + class TestingListIntervalScript(DefaultScript): """ A script that does nothing. Used to test listing of script with nonzero intervals. """ + def at_script_creation(self): """ Setup the script @@ -66,11 +69,13 @@ class TestingListIntervalScript(DefaultScript): self.interval = 1 self.repeats = 1 + class TestScriptHandler(BaseEvenniaTest): """ Test the ScriptHandler class. """ + def setUp(self): self.obj, self.errors = DefaultObject.create("test_object") @@ -82,7 +87,7 @@ class TestScriptHandler(BaseEvenniaTest): self.obj.scripts.add(TestingListIntervalScript) self.num = self.obj.scripts.start(self.obj.scripts.all()[0].key) self.assertTrue(self.num == 1) - + def test_list_script_intervals(self): "Checks that Scripthandler __str__ function lists script intervals correctly" self.obj.scripts.add(TestingListIntervalScript) @@ -90,6 +95,13 @@ class TestScriptHandler(BaseEvenniaTest): self.assertTrue("None/1" in self.str) self.assertTrue("1 repeats" in self.str) + def test_get_script(self): + "Checks that Scripthandler get function returns correct script" + self.obj.scripts.add(TestingListIntervalScript) + script = self.obj.scripts.get("interval_test") + self.assertTrue(bool(script)) + + class TestScriptDB(TestCase): "Check the singleton/static ScriptDB object works correctly" @@ -161,14 +173,14 @@ class TestExtendedLoopingCall(TestCase): loopcall._scheduleFrom.assert_called_with(121) def test_start_invalid_interval(self): - """ Test the .start method with interval less than zero """ + """Test the .start method with interval less than zero""" with self.assertRaises(ValueError): callback = mock.MagicMock() loopcall = ExtendedLoopingCall(callback) loopcall.start(-1, now=True, start_delay=None, count_start=1) def test__call__when_delay(self): - """ Test __call__ modifies start_delay and starttime if start_delay was previously set """ + """Test __call__ modifies start_delay and starttime if start_delay was previously set""" callback = mock.MagicMock() loopcall = ExtendedLoopingCall(callback) loopcall.clock.seconds = mock.MagicMock(return_value=1) @@ -176,12 +188,12 @@ class TestExtendedLoopingCall(TestCase): loopcall.starttime = 0 loopcall() - + self.assertEqual(loopcall.start_delay, None) self.assertEqual(loopcall.starttime, 1) def test_force_repeat(self): - """ Test forcing script to run that is scheduled to run in the future """ + """Test forcing script to run that is scheduled to run in the future""" callback = mock.MagicMock() loopcall = ExtendedLoopingCall(callback) loopcall.clock.seconds = mock.MagicMock(return_value=0) @@ -192,10 +204,12 @@ class TestExtendedLoopingCall(TestCase): callback.assert_called_once() + def dummy_func(): - """ Dummy function used as callback parameter """ + """Dummy function used as callback parameter""" return 0 + class TestMonitorHandler(TestCase): """ Test the MonitorHandler class. @@ -220,13 +234,13 @@ class TestMonitorHandler(TestCase): def test_remove(self): """Tests that removing an object from the monitor handler works correctly""" obj = mock.Mock() - fieldname = 'db_remove' + fieldname = "db_remove" callback = dummy_func - idstring = 'test_remove' + idstring = "test_remove" """Add an object to the monitor handler and then remove it""" - self.handler.add(obj,fieldname,callback,idstring=idstring) - self.handler.remove(obj,fieldname,idstring=idstring) + self.handler.add(obj, fieldname, callback, idstring=idstring) + self.handler.remove(obj, fieldname, idstring=idstring) self.assertEquals(self.handler.monitors[obj][fieldname], {}) def test_add_with_invalid_function(self): @@ -234,25 +248,29 @@ class TestMonitorHandler(TestCase): """Tests that add method rejects objects where callback is not a function""" fieldname = "db_key" callback = "not_a_function" - + self.handler.add(obj, fieldname, callback) self.assertNotIn(fieldname, self.handler.monitors[obj]) def test_all(self): """Tests that all method correctly returns information about added objects""" - obj = [mock.Mock(),mock.Mock()] - fieldname = ["db_all1","db_all2"] + obj = [mock.Mock(), mock.Mock()] + fieldname = ["db_all1", "db_all2"] callback = dummy_func - idstring = ["test_all1","test_all2"] + idstring = ["test_all1", "test_all2"] self.handler.add(obj[0], fieldname[0], callback, idstring=idstring[0]) - self.handler.add(obj[1], fieldname[1], callback, idstring=idstring[1],persistent=True) - + self.handler.add(obj[1], fieldname[1], callback, idstring=idstring[1], persistent=True) + output = self.handler.all() - self.assertEquals(output, - [(obj[0], fieldname[0], idstring[0], False, {}), - (obj[1], fieldname[1], idstring[1], True, {})]) - + self.assertEquals( + output, + [ + (obj[0], fieldname[0], idstring[0], False, {}), + (obj[1], fieldname[1], idstring[1], True, {}), + ], + ) + def test_clear(self): """Tests that the clear function correctly clears the monitor handler""" obj = mock.Mock() @@ -277,7 +295,7 @@ class TestMonitorHandler(TestCase): category = "testattribute" """Add attribute to handler and assert that it has been added""" - self.handler.add(obj, fieldname, callback, idstring=idstring,category=category) + self.handler.add(obj, fieldname, callback, idstring=idstring, category=category) index = obj.attributes.get(fieldname, return_obj=True) name = "db_value[testattribute]" @@ -287,5 +305,5 @@ class TestMonitorHandler(TestCase): self.assertEqual(self.handler.monitors[index][name][idstring], (callback, False, {})) """Remove attribute from the handler and assert that it is gone""" - self.handler.remove(obj,fieldname,idstring=idstring,category=category) + self.handler.remove(obj, fieldname, idstring=idstring, category=category) self.assertEquals(self.handler.monitors[index][name], {})