fix: Add validity check to TaskHandlerTask

This commit is contained in:
Jake 2026-03-04 21:45:59 -08:00
parent 3761a7cb21
commit 5a5304f38f
2 changed files with 99 additions and 6 deletions

View file

@ -52,6 +52,24 @@ class TaskHandlerTask:
self.task_id = task_id
self.deferred = TASK_HANDLER.get_deferred(task_id)
def _is_valid(self):
"""Check if this task reference still points to the original task.
A task reference becomes invalid when the original task is completed or
removed and its ID is reassigned to a new, unrelated task. This prevents
accidentally operating on the wrong task.
Returns:
bool: True if this reference still points to its original task.
"""
# task data is stored as (comp_time, callback, args, kwargs, persistent, deferred);
# compare the deferred (index 5) by identity to detect ID reuse
task_data = TASK_HANDLER.tasks.get(self.task_id)
if task_data is None:
return False
return task_data[5] is self.deferred
def get_deferred(self):
"""Return the instance of the deferred the task id is using.
@ -60,6 +78,8 @@ class TaskHandlerTask:
None is returned if there is no deferred affiliated with this id.
"""
if not self._is_valid():
return None
return TASK_HANDLER.get_deferred(self.task_id)
def pause(self):
@ -111,6 +131,8 @@ class TaskHandlerTask:
handler. Otherwise it will be the return of the task's callback.
"""
if not self._is_valid():
return False
return TASK_HANDLER.do_task(self.task_id)
def call(self):
@ -125,19 +147,21 @@ class TaskHandlerTask:
handler. Otherwise it will be the return of the task's callback.
"""
if not self._is_valid():
return False
return TASK_HANDLER.call_task(self.task_id)
def remove(self):
"""Remove a task without executing it.
Deletes the instance of the task's deferred.
Args:
task_id (int): an existing task ID.
Returns:
bool: True if the removal completed successfully.
False if the task reference is no longer valid.
"""
if not self._is_valid():
return False
return TASK_HANDLER.remove(self.task_id)
def cancel(self):
@ -149,6 +173,8 @@ class TaskHandlerTask:
False if the cancel did not complete successfully.
"""
if not self._is_valid():
return False
return TASK_HANDLER.cancel(self.task_id)
def active(self):
@ -159,6 +185,8 @@ class TaskHandlerTask:
it is not (has been called) or if the task does not exist.
"""
if not self._is_valid():
return False
return TASK_HANDLER.active(self.task_id)
@property
@ -190,6 +218,8 @@ class TaskHandlerTask:
bool: True the task exists False if it does not.
"""
if not self._is_valid():
return False
return TASK_HANDLER.exists(self.task_id)
def get_id(self):

View file

@ -6,8 +6,6 @@ Unit tests for the scripts package
from collections import defaultdict
from unittest import TestCase, mock
from parameterized import parameterized
from evennia import DefaultScript
from evennia.objects.objects import DefaultObject
from evennia.scripts.manager import ScriptDBManager
@ -15,10 +13,10 @@ from evennia.scripts.models import ObjectDoesNotExist, ScriptDB
from evennia.scripts.monitorhandler import MonitorHandler
from evennia.scripts.ondemandhandler import OnDemandHandler, OnDemandTask
from evennia.scripts.scripts import DoNothing, ExtendedLoopingCall
from evennia.scripts.taskhandler import TASK_HANDLER
from evennia.scripts.tickerhandler import TickerHandler
from evennia.typeclasses.attributes import AttributeProperty
from evennia.utils.create import create_script
from evennia.utils.dbserialize import dbserialize
from evennia.utils.test_resources import BaseEvenniaTest, EvenniaTest
@ -382,6 +380,71 @@ class TestMonitorHandler(TestCase):
self.assertEqual(self.handler.monitors[index][name], {})
class TestTaskHandlerTask(TestCase):
"""Test that TaskHandlerTask correctly handles stale references when task IDs are reused."""
def setUp(self):
from twisted.internet import task as twisted_task
TASK_HANDLER.clock = twisted_task.Clock()
TASK_HANDLER.clear()
def tearDown(self):
TASK_HANDLER.clear()
def test_stale_reference_after_id_reuse(self):
"""A stale TaskHandlerTask must not operate on a new task that reused its ID."""
callback1 = mock.Mock(return_value="result1")
callback2 = mock.Mock(return_value="result2")
# Create first task (gets ID 1)
task1 = TASK_HANDLER.add(5, callback1)
task1_id = task1.get_id()
# Complete and remove the first task so its ID is freed
TASK_HANDLER.clock.advance(5)
# Create second task - should reuse ID 1
task2 = TASK_HANDLER.add(5, callback2)
self.assertEqual(task2.get_id(), task1_id)
# The stale reference (task1) must not affect the new task (task2)
self.assertFalse(task1.exists())
self.assertFalse(task1.active())
self.assertFalse(task1.cancel())
self.assertFalse(task1.remove())
self.assertFalse(task1.do_task())
self.assertFalse(task1.call())
self.assertIsNone(task1.get_deferred())
# The new task must still be intact
self.assertTrue(task2.exists())
self.assertTrue(task2.active())
def test_valid_reference_works_normally(self):
"""A valid TaskHandlerTask should work as expected."""
callback = mock.Mock(return_value="result")
task = TASK_HANDLER.add(5, callback)
self.assertTrue(task.exists())
self.assertTrue(task.active())
self.assertIsNotNone(task.get_deferred())
result = task.call()
self.assertEqual(result, "result")
callback.assert_called_once()
def test_is_valid_after_remove(self):
"""After removing a task, its TaskHandlerTask reference should be invalid."""
callback = mock.Mock()
task = TASK_HANDLER.add(5, callback)
task.remove()
self.assertFalse(task.exists())
self.assertFalse(task.active())
self.assertFalse(task.cancel())
class TestOnDemandTask(EvenniaTest):
"""
Test the OnDemandTask class.