diff --git a/evennia/scripts/taskhandler.py b/evennia/scripts/taskhandler.py index 6233b2f5d0..0a3e6174a2 100644 --- a/evennia/scripts/taskhandler.py +++ b/evennia/scripts/taskhandler.py @@ -52,6 +52,24 @@ class TaskHandlerTask: self.task_id = task_id self.deferred = TASK_HANDLER.get_deferred(task_id) + def _is_valid(self): + """Check if this task reference still points to the original task. + + A task reference becomes invalid when the original task is completed or + removed and its ID is reassigned to a new, unrelated task. This prevents + accidentally operating on the wrong task. + + Returns: + bool: True if this reference still points to its original task. + + """ + # task data is stored as (comp_time, callback, args, kwargs, persistent, deferred); + # compare the deferred (index 5) by identity to detect ID reuse + task_data = TASK_HANDLER.tasks.get(self.task_id) + if task_data is None: + return False + return task_data[5] is self.deferred + def get_deferred(self): """Return the instance of the deferred the task id is using. @@ -60,6 +78,8 @@ class TaskHandlerTask: None is returned if there is no deferred affiliated with this id. """ + if not self._is_valid(): + return None return TASK_HANDLER.get_deferred(self.task_id) def pause(self): @@ -111,6 +131,8 @@ class TaskHandlerTask: handler. Otherwise it will be the return of the task's callback. """ + if not self._is_valid(): + return False return TASK_HANDLER.do_task(self.task_id) def call(self): @@ -125,19 +147,21 @@ class TaskHandlerTask: handler. Otherwise it will be the return of the task's callback. """ + if not self._is_valid(): + return False return TASK_HANDLER.call_task(self.task_id) def remove(self): """Remove a task without executing it. Deletes the instance of the task's deferred. - Args: - task_id (int): an existing task ID. - Returns: bool: True if the removal completed successfully. + False if the task reference is no longer valid. """ + if not self._is_valid(): + return False return TASK_HANDLER.remove(self.task_id) def cancel(self): @@ -149,6 +173,8 @@ class TaskHandlerTask: False if the cancel did not complete successfully. """ + if not self._is_valid(): + return False return TASK_HANDLER.cancel(self.task_id) def active(self): @@ -159,6 +185,8 @@ class TaskHandlerTask: it is not (has been called) or if the task does not exist. """ + if not self._is_valid(): + return False return TASK_HANDLER.active(self.task_id) @property @@ -190,6 +218,8 @@ class TaskHandlerTask: bool: True the task exists False if it does not. """ + if not self._is_valid(): + return False return TASK_HANDLER.exists(self.task_id) def get_id(self): @@ -273,16 +303,23 @@ class TaskHandler: """ clean_ids = [] + # if a now time is provided use it (intended for unit testing) + now = self._now if self._now else datetime.now() for task_id, (date, callback, args, kwargs, persistent, _) in self.tasks.items(): if not self.active(task_id): stale_date = date + timedelta(seconds=self.stale_timeout) - # if a now time is provided use it (intended for unit testing) - now = self._now if self._now else datetime.now() # the task was canceled more than stale_timeout seconds ago if now > stale_date: clean_ids.append(task_id) + needs_save = False for task_id in clean_ids: - self.remove(task_id) + self.cancel(task_id) + del self.tasks[task_id] + if task_id in self.to_save: + del self.to_save[task_id] + needs_save = True + if needs_save: + self.save() return True def save(self): @@ -351,9 +388,8 @@ class TaskHandler: delta = timedelta(seconds=timedelay) comp_time = now + delta # get an open task id - used_ids = list(self.tasks.keys()) task_id = 1 - while task_id in used_ids: + while task_id in self.tasks: task_id += 1 # record the task to the tasks dictionary @@ -395,19 +431,13 @@ class TaskHandler: self.tasks[task_id] = (comp_time, callback, args, kwargs, persistent, None) # defer the task - callback = self.do_task - args = [task_id] - kwargs = {} - d = deferLater(self.clock, timedelay, callback, *args, **kwargs) + d = deferLater(self.clock, timedelay, self.do_task, task_id) d.addErrback(handle_error) # some tasks may complete before the deferred can be added if task_id in self.tasks: - task = self.tasks.get(task_id) - task = list(task) - task[4] = persistent - task[5] = d - self.tasks[task_id] = task + comp_time, cb, args, kwargs, _, _ = self.tasks[task_id] + self.tasks[task_id] = (comp_time, cb, args, kwargs, persistent, d) else: # the task already completed return False if self.stale_timeout > 0: @@ -426,10 +456,7 @@ class TaskHandler: bool: True the task exists False if it does not. """ - if task_id in self.tasks: - return True - else: - return False + return task_id in self.tasks def active(self, task_id): """ @@ -444,7 +471,6 @@ class TaskHandler: """ if task_id in self.tasks: - # if the task has not been run, cancel it deferred = self.get_deferred(task_id) return not (deferred is not None and deferred.called) else: @@ -464,7 +490,6 @@ class TaskHandler: """ if task_id in self.tasks: - # if the task has not been run, cancel it d = self.get_deferred(task_id) if d is not None: # it is remotely possible for a task to not have a deferred if d.called: @@ -489,19 +514,14 @@ class TaskHandler: bool: True if the removal completed successfully. """ - d = None # delete the task from the tasks dictionary if task_id in self.tasks: - # if the task has not been run, cancel it self.cancel(task_id) - del self.tasks[task_id] # delete the task from the tasks dictionary + del self.tasks[task_id] # remove the task from the persistent dictionary and ServerConfig if task_id in self.to_save: del self.to_save[task_id] - self.save() # remove from ServerConfig.objects - # delete the instance of the deferred - if d: - del d + self.save() return True def clear(self, save=True, cancel=True): diff --git a/evennia/scripts/tests.py b/evennia/scripts/tests.py index 5e0f23c9bd..a2d31c18dc 100644 --- a/evennia/scripts/tests.py +++ b/evennia/scripts/tests.py @@ -6,8 +6,6 @@ Unit tests for the scripts package from collections import defaultdict from unittest import TestCase, mock -from parameterized import parameterized - from evennia import DefaultScript from evennia.objects.objects import DefaultObject from evennia.scripts.manager import ScriptDBManager @@ -15,10 +13,10 @@ from evennia.scripts.models import ObjectDoesNotExist, ScriptDB from evennia.scripts.monitorhandler import MonitorHandler from evennia.scripts.ondemandhandler import OnDemandHandler, OnDemandTask from evennia.scripts.scripts import DoNothing, ExtendedLoopingCall +from evennia.scripts.taskhandler import TASK_HANDLER from evennia.scripts.tickerhandler import TickerHandler from evennia.typeclasses.attributes import AttributeProperty from evennia.utils.create import create_script -from evennia.utils.dbserialize import dbserialize from evennia.utils.test_resources import BaseEvenniaTest, EvenniaTest @@ -382,6 +380,71 @@ class TestMonitorHandler(TestCase): self.assertEqual(self.handler.monitors[index][name], {}) +class TestTaskHandlerTask(TestCase): + """Test that TaskHandlerTask correctly handles stale references when task IDs are reused.""" + + def setUp(self): + from twisted.internet import task as twisted_task + + TASK_HANDLER.clock = twisted_task.Clock() + TASK_HANDLER.clear() + + def tearDown(self): + TASK_HANDLER.clear() + + def test_stale_reference_after_id_reuse(self): + """A stale TaskHandlerTask must not operate on a new task that reused its ID.""" + callback1 = mock.Mock(return_value="result1") + callback2 = mock.Mock(return_value="result2") + + # Create first task (gets ID 1) + task1 = TASK_HANDLER.add(5, callback1) + task1_id = task1.get_id() + + # Complete and remove the first task so its ID is freed + TASK_HANDLER.clock.advance(5) + + # Create second task - should reuse ID 1 + task2 = TASK_HANDLER.add(5, callback2) + self.assertEqual(task2.get_id(), task1_id) + + # The stale reference (task1) must not affect the new task (task2) + self.assertFalse(task1.exists()) + self.assertFalse(task1.active()) + self.assertFalse(task1.cancel()) + self.assertFalse(task1.remove()) + self.assertFalse(task1.do_task()) + self.assertFalse(task1.call()) + self.assertIsNone(task1.get_deferred()) + + # The new task must still be intact + self.assertTrue(task2.exists()) + self.assertTrue(task2.active()) + + def test_valid_reference_works_normally(self): + """A valid TaskHandlerTask should work as expected.""" + callback = mock.Mock(return_value="result") + task = TASK_HANDLER.add(5, callback) + + self.assertTrue(task.exists()) + self.assertTrue(task.active()) + self.assertIsNotNone(task.get_deferred()) + + result = task.call() + self.assertEqual(result, "result") + callback.assert_called_once() + + def test_is_valid_after_remove(self): + """After removing a task, its TaskHandlerTask reference should be invalid.""" + callback = mock.Mock() + task = TASK_HANDLER.add(5, callback) + + task.remove() + self.assertFalse(task.exists()) + self.assertFalse(task.active()) + self.assertFalse(task.cancel()) + + class TestOnDemandTask(EvenniaTest): """ Test the OnDemandTask class.