This commit is contained in:
Jake 2026-03-05 06:15:00 +00:00 committed by GitHub
commit 49a0bcd4cc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 116 additions and 33 deletions

View file

@ -52,6 +52,24 @@ class TaskHandlerTask:
self.task_id = task_id
self.deferred = TASK_HANDLER.get_deferred(task_id)
def _is_valid(self):
"""Check if this task reference still points to the original task.
A task reference becomes invalid when the original task is completed or
removed and its ID is reassigned to a new, unrelated task. This prevents
accidentally operating on the wrong task.
Returns:
bool: True if this reference still points to its original task.
"""
# task data is stored as (comp_time, callback, args, kwargs, persistent, deferred);
# compare the deferred (index 5) by identity to detect ID reuse
task_data = TASK_HANDLER.tasks.get(self.task_id)
if task_data is None:
return False
return task_data[5] is self.deferred
def get_deferred(self):
"""Return the instance of the deferred the task id is using.
@ -60,6 +78,8 @@ class TaskHandlerTask:
None is returned if there is no deferred affiliated with this id.
"""
if not self._is_valid():
return None
return TASK_HANDLER.get_deferred(self.task_id)
def pause(self):
@ -111,6 +131,8 @@ class TaskHandlerTask:
handler. Otherwise it will be the return of the task's callback.
"""
if not self._is_valid():
return False
return TASK_HANDLER.do_task(self.task_id)
def call(self):
@ -125,19 +147,21 @@ class TaskHandlerTask:
handler. Otherwise it will be the return of the task's callback.
"""
if not self._is_valid():
return False
return TASK_HANDLER.call_task(self.task_id)
def remove(self):
"""Remove a task without executing it.
Deletes the instance of the task's deferred.
Args:
task_id (int): an existing task ID.
Returns:
bool: True if the removal completed successfully.
False if the task reference is no longer valid.
"""
if not self._is_valid():
return False
return TASK_HANDLER.remove(self.task_id)
def cancel(self):
@ -149,6 +173,8 @@ class TaskHandlerTask:
False if the cancel did not complete successfully.
"""
if not self._is_valid():
return False
return TASK_HANDLER.cancel(self.task_id)
def active(self):
@ -159,6 +185,8 @@ class TaskHandlerTask:
it is not (has been called) or if the task does not exist.
"""
if not self._is_valid():
return False
return TASK_HANDLER.active(self.task_id)
@property
@ -190,6 +218,8 @@ class TaskHandlerTask:
bool: True the task exists False if it does not.
"""
if not self._is_valid():
return False
return TASK_HANDLER.exists(self.task_id)
def get_id(self):
@ -273,16 +303,23 @@ class TaskHandler:
"""
clean_ids = []
# if a now time is provided use it (intended for unit testing)
now = self._now if self._now else datetime.now()
for task_id, (date, callback, args, kwargs, persistent, _) in self.tasks.items():
if not self.active(task_id):
stale_date = date + timedelta(seconds=self.stale_timeout)
# if a now time is provided use it (intended for unit testing)
now = self._now if self._now else datetime.now()
# the task was canceled more than stale_timeout seconds ago
if now > stale_date:
clean_ids.append(task_id)
needs_save = False
for task_id in clean_ids:
self.remove(task_id)
self.cancel(task_id)
del self.tasks[task_id]
if task_id in self.to_save:
del self.to_save[task_id]
needs_save = True
if needs_save:
self.save()
return True
def save(self):
@ -351,9 +388,8 @@ class TaskHandler:
delta = timedelta(seconds=timedelay)
comp_time = now + delta
# get an open task id
used_ids = list(self.tasks.keys())
task_id = 1
while task_id in used_ids:
while task_id in self.tasks:
task_id += 1
# record the task to the tasks dictionary
@ -395,19 +431,13 @@ class TaskHandler:
self.tasks[task_id] = (comp_time, callback, args, kwargs, persistent, None)
# defer the task
callback = self.do_task
args = [task_id]
kwargs = {}
d = deferLater(self.clock, timedelay, callback, *args, **kwargs)
d = deferLater(self.clock, timedelay, self.do_task, task_id)
d.addErrback(handle_error)
# some tasks may complete before the deferred can be added
if task_id in self.tasks:
task = self.tasks.get(task_id)
task = list(task)
task[4] = persistent
task[5] = d
self.tasks[task_id] = task
comp_time, cb, args, kwargs, _, _ = self.tasks[task_id]
self.tasks[task_id] = (comp_time, cb, args, kwargs, persistent, d)
else: # the task already completed
return False
if self.stale_timeout > 0:
@ -426,10 +456,7 @@ class TaskHandler:
bool: True the task exists False if it does not.
"""
if task_id in self.tasks:
return True
else:
return False
return task_id in self.tasks
def active(self, task_id):
"""
@ -444,7 +471,6 @@ class TaskHandler:
"""
if task_id in self.tasks:
# if the task has not been run, cancel it
deferred = self.get_deferred(task_id)
return not (deferred is not None and deferred.called)
else:
@ -464,7 +490,6 @@ class TaskHandler:
"""
if task_id in self.tasks:
# if the task has not been run, cancel it
d = self.get_deferred(task_id)
if d is not None: # it is remotely possible for a task to not have a deferred
if d.called:
@ -489,19 +514,14 @@ class TaskHandler:
bool: True if the removal completed successfully.
"""
d = None
# delete the task from the tasks dictionary
if task_id in self.tasks:
# if the task has not been run, cancel it
self.cancel(task_id)
del self.tasks[task_id] # delete the task from the tasks dictionary
del self.tasks[task_id]
# remove the task from the persistent dictionary and ServerConfig
if task_id in self.to_save:
del self.to_save[task_id]
self.save() # remove from ServerConfig.objects
# delete the instance of the deferred
if d:
del d
self.save()
return True
def clear(self, save=True, cancel=True):

View file

@ -6,8 +6,6 @@ Unit tests for the scripts package
from collections import defaultdict
from unittest import TestCase, mock
from parameterized import parameterized
from evennia import DefaultScript
from evennia.objects.objects import DefaultObject
from evennia.scripts.manager import ScriptDBManager
@ -15,10 +13,10 @@ from evennia.scripts.models import ObjectDoesNotExist, ScriptDB
from evennia.scripts.monitorhandler import MonitorHandler
from evennia.scripts.ondemandhandler import OnDemandHandler, OnDemandTask
from evennia.scripts.scripts import DoNothing, ExtendedLoopingCall
from evennia.scripts.taskhandler import TASK_HANDLER
from evennia.scripts.tickerhandler import TickerHandler
from evennia.typeclasses.attributes import AttributeProperty
from evennia.utils.create import create_script
from evennia.utils.dbserialize import dbserialize
from evennia.utils.test_resources import BaseEvenniaTest, EvenniaTest
@ -382,6 +380,71 @@ class TestMonitorHandler(TestCase):
self.assertEqual(self.handler.monitors[index][name], {})
class TestTaskHandlerTask(TestCase):
"""Test that TaskHandlerTask correctly handles stale references when task IDs are reused."""
def setUp(self):
from twisted.internet import task as twisted_task
TASK_HANDLER.clock = twisted_task.Clock()
TASK_HANDLER.clear()
def tearDown(self):
TASK_HANDLER.clear()
def test_stale_reference_after_id_reuse(self):
"""A stale TaskHandlerTask must not operate on a new task that reused its ID."""
callback1 = mock.Mock(return_value="result1")
callback2 = mock.Mock(return_value="result2")
# Create first task (gets ID 1)
task1 = TASK_HANDLER.add(5, callback1)
task1_id = task1.get_id()
# Complete and remove the first task so its ID is freed
TASK_HANDLER.clock.advance(5)
# Create second task - should reuse ID 1
task2 = TASK_HANDLER.add(5, callback2)
self.assertEqual(task2.get_id(), task1_id)
# The stale reference (task1) must not affect the new task (task2)
self.assertFalse(task1.exists())
self.assertFalse(task1.active())
self.assertFalse(task1.cancel())
self.assertFalse(task1.remove())
self.assertFalse(task1.do_task())
self.assertFalse(task1.call())
self.assertIsNone(task1.get_deferred())
# The new task must still be intact
self.assertTrue(task2.exists())
self.assertTrue(task2.active())
def test_valid_reference_works_normally(self):
"""A valid TaskHandlerTask should work as expected."""
callback = mock.Mock(return_value="result")
task = TASK_HANDLER.add(5, callback)
self.assertTrue(task.exists())
self.assertTrue(task.active())
self.assertIsNotNone(task.get_deferred())
result = task.call()
self.assertEqual(result, "result")
callback.assert_called_once()
def test_is_valid_after_remove(self):
"""After removing a task, its TaskHandlerTask reference should be invalid."""
callback = mock.Mock()
task = TASK_HANDLER.add(5, callback)
task.remove()
self.assertFalse(task.exists())
self.assertFalse(task.active())
self.assertFalse(task.cancel())
class TestOnDemandTask(EvenniaTest):
"""
Test the OnDemandTask class.