diff --git a/evennia/scripts/taskhandler.py b/evennia/scripts/taskhandler.py index 6233b2f5d0..feb8a48dc1 100644 --- a/evennia/scripts/taskhandler.py +++ b/evennia/scripts/taskhandler.py @@ -52,6 +52,24 @@ class TaskHandlerTask: self.task_id = task_id self.deferred = TASK_HANDLER.get_deferred(task_id) + def _is_valid(self): + """Check if this task reference still points to the original task. + + A task reference becomes invalid when the original task is completed or + removed and its ID is reassigned to a new, unrelated task. This prevents + accidentally operating on the wrong task. + + Returns: + bool: True if this reference still points to its original task. + + """ + # task data is stored as (comp_time, callback, args, kwargs, persistent, deferred); + # compare the deferred (index 5) by identity to detect ID reuse + task_data = TASK_HANDLER.tasks.get(self.task_id) + if task_data is None: + return False + return task_data[5] is self.deferred + def get_deferred(self): """Return the instance of the deferred the task id is using. @@ -60,6 +78,8 @@ class TaskHandlerTask: None is returned if there is no deferred affiliated with this id. """ + if not self._is_valid(): + return None return TASK_HANDLER.get_deferred(self.task_id) def pause(self): @@ -111,6 +131,8 @@ class TaskHandlerTask: handler. Otherwise it will be the return of the task's callback. """ + if not self._is_valid(): + return False return TASK_HANDLER.do_task(self.task_id) def call(self): @@ -125,19 +147,21 @@ class TaskHandlerTask: handler. Otherwise it will be the return of the task's callback. """ + if not self._is_valid(): + return False return TASK_HANDLER.call_task(self.task_id) def remove(self): """Remove a task without executing it. Deletes the instance of the task's deferred. - Args: - task_id (int): an existing task ID. - Returns: bool: True if the removal completed successfully. + False if the task reference is no longer valid. """ + if not self._is_valid(): + return False return TASK_HANDLER.remove(self.task_id) def cancel(self): @@ -149,6 +173,8 @@ class TaskHandlerTask: False if the cancel did not complete successfully. """ + if not self._is_valid(): + return False return TASK_HANDLER.cancel(self.task_id) def active(self): @@ -159,6 +185,8 @@ class TaskHandlerTask: it is not (has been called) or if the task does not exist. """ + if not self._is_valid(): + return False return TASK_HANDLER.active(self.task_id) @property @@ -190,6 +218,8 @@ class TaskHandlerTask: bool: True the task exists False if it does not. """ + if not self._is_valid(): + return False return TASK_HANDLER.exists(self.task_id) def get_id(self): diff --git a/evennia/scripts/tests.py b/evennia/scripts/tests.py index 5e0f23c9bd..a2d31c18dc 100644 --- a/evennia/scripts/tests.py +++ b/evennia/scripts/tests.py @@ -6,8 +6,6 @@ Unit tests for the scripts package from collections import defaultdict from unittest import TestCase, mock -from parameterized import parameterized - from evennia import DefaultScript from evennia.objects.objects import DefaultObject from evennia.scripts.manager import ScriptDBManager @@ -15,10 +13,10 @@ from evennia.scripts.models import ObjectDoesNotExist, ScriptDB from evennia.scripts.monitorhandler import MonitorHandler from evennia.scripts.ondemandhandler import OnDemandHandler, OnDemandTask from evennia.scripts.scripts import DoNothing, ExtendedLoopingCall +from evennia.scripts.taskhandler import TASK_HANDLER from evennia.scripts.tickerhandler import TickerHandler from evennia.typeclasses.attributes import AttributeProperty from evennia.utils.create import create_script -from evennia.utils.dbserialize import dbserialize from evennia.utils.test_resources import BaseEvenniaTest, EvenniaTest @@ -382,6 +380,71 @@ class TestMonitorHandler(TestCase): self.assertEqual(self.handler.monitors[index][name], {}) +class TestTaskHandlerTask(TestCase): + """Test that TaskHandlerTask correctly handles stale references when task IDs are reused.""" + + def setUp(self): + from twisted.internet import task as twisted_task + + TASK_HANDLER.clock = twisted_task.Clock() + TASK_HANDLER.clear() + + def tearDown(self): + TASK_HANDLER.clear() + + def test_stale_reference_after_id_reuse(self): + """A stale TaskHandlerTask must not operate on a new task that reused its ID.""" + callback1 = mock.Mock(return_value="result1") + callback2 = mock.Mock(return_value="result2") + + # Create first task (gets ID 1) + task1 = TASK_HANDLER.add(5, callback1) + task1_id = task1.get_id() + + # Complete and remove the first task so its ID is freed + TASK_HANDLER.clock.advance(5) + + # Create second task - should reuse ID 1 + task2 = TASK_HANDLER.add(5, callback2) + self.assertEqual(task2.get_id(), task1_id) + + # The stale reference (task1) must not affect the new task (task2) + self.assertFalse(task1.exists()) + self.assertFalse(task1.active()) + self.assertFalse(task1.cancel()) + self.assertFalse(task1.remove()) + self.assertFalse(task1.do_task()) + self.assertFalse(task1.call()) + self.assertIsNone(task1.get_deferred()) + + # The new task must still be intact + self.assertTrue(task2.exists()) + self.assertTrue(task2.active()) + + def test_valid_reference_works_normally(self): + """A valid TaskHandlerTask should work as expected.""" + callback = mock.Mock(return_value="result") + task = TASK_HANDLER.add(5, callback) + + self.assertTrue(task.exists()) + self.assertTrue(task.active()) + self.assertIsNotNone(task.get_deferred()) + + result = task.call() + self.assertEqual(result, "result") + callback.assert_called_once() + + def test_is_valid_after_remove(self): + """After removing a task, its TaskHandlerTask reference should be invalid.""" + callback = mock.Mock() + task = TASK_HANDLER.add(5, callback) + + task.remove() + self.assertFalse(task.exists()) + self.assertFalse(task.active()) + self.assertFalse(task.cancel()) + + class TestOnDemandTask(EvenniaTest): """ Test the OnDemandTask class.