mirror of
https://github.com/evennia/evennia.git
synced 2026-03-18 22:06:30 +01:00
Merge 53dbc57ba3 into 3761a7cb21
This commit is contained in:
commit
49a0bcd4cc
2 changed files with 116 additions and 33 deletions
|
|
@ -52,6 +52,24 @@ class TaskHandlerTask:
|
|||
self.task_id = task_id
|
||||
self.deferred = TASK_HANDLER.get_deferred(task_id)
|
||||
|
||||
def _is_valid(self):
|
||||
"""Check if this task reference still points to the original task.
|
||||
|
||||
A task reference becomes invalid when the original task is completed or
|
||||
removed and its ID is reassigned to a new, unrelated task. This prevents
|
||||
accidentally operating on the wrong task.
|
||||
|
||||
Returns:
|
||||
bool: True if this reference still points to its original task.
|
||||
|
||||
"""
|
||||
# task data is stored as (comp_time, callback, args, kwargs, persistent, deferred);
|
||||
# compare the deferred (index 5) by identity to detect ID reuse
|
||||
task_data = TASK_HANDLER.tasks.get(self.task_id)
|
||||
if task_data is None:
|
||||
return False
|
||||
return task_data[5] is self.deferred
|
||||
|
||||
def get_deferred(self):
|
||||
"""Return the instance of the deferred the task id is using.
|
||||
|
||||
|
|
@ -60,6 +78,8 @@ class TaskHandlerTask:
|
|||
None is returned if there is no deferred affiliated with this id.
|
||||
|
||||
"""
|
||||
if not self._is_valid():
|
||||
return None
|
||||
return TASK_HANDLER.get_deferred(self.task_id)
|
||||
|
||||
def pause(self):
|
||||
|
|
@ -111,6 +131,8 @@ class TaskHandlerTask:
|
|||
handler. Otherwise it will be the return of the task's callback.
|
||||
|
||||
"""
|
||||
if not self._is_valid():
|
||||
return False
|
||||
return TASK_HANDLER.do_task(self.task_id)
|
||||
|
||||
def call(self):
|
||||
|
|
@ -125,19 +147,21 @@ class TaskHandlerTask:
|
|||
handler. Otherwise it will be the return of the task's callback.
|
||||
|
||||
"""
|
||||
if not self._is_valid():
|
||||
return False
|
||||
return TASK_HANDLER.call_task(self.task_id)
|
||||
|
||||
def remove(self):
|
||||
"""Remove a task without executing it.
|
||||
Deletes the instance of the task's deferred.
|
||||
|
||||
Args:
|
||||
task_id (int): an existing task ID.
|
||||
|
||||
Returns:
|
||||
bool: True if the removal completed successfully.
|
||||
False if the task reference is no longer valid.
|
||||
|
||||
"""
|
||||
if not self._is_valid():
|
||||
return False
|
||||
return TASK_HANDLER.remove(self.task_id)
|
||||
|
||||
def cancel(self):
|
||||
|
|
@ -149,6 +173,8 @@ class TaskHandlerTask:
|
|||
False if the cancel did not complete successfully.
|
||||
|
||||
"""
|
||||
if not self._is_valid():
|
||||
return False
|
||||
return TASK_HANDLER.cancel(self.task_id)
|
||||
|
||||
def active(self):
|
||||
|
|
@ -159,6 +185,8 @@ class TaskHandlerTask:
|
|||
it is not (has been called) or if the task does not exist.
|
||||
|
||||
"""
|
||||
if not self._is_valid():
|
||||
return False
|
||||
return TASK_HANDLER.active(self.task_id)
|
||||
|
||||
@property
|
||||
|
|
@ -190,6 +218,8 @@ class TaskHandlerTask:
|
|||
bool: True the task exists False if it does not.
|
||||
|
||||
"""
|
||||
if not self._is_valid():
|
||||
return False
|
||||
return TASK_HANDLER.exists(self.task_id)
|
||||
|
||||
def get_id(self):
|
||||
|
|
@ -273,16 +303,23 @@ class TaskHandler:
|
|||
|
||||
"""
|
||||
clean_ids = []
|
||||
# if a now time is provided use it (intended for unit testing)
|
||||
now = self._now if self._now else datetime.now()
|
||||
for task_id, (date, callback, args, kwargs, persistent, _) in self.tasks.items():
|
||||
if not self.active(task_id):
|
||||
stale_date = date + timedelta(seconds=self.stale_timeout)
|
||||
# if a now time is provided use it (intended for unit testing)
|
||||
now = self._now if self._now else datetime.now()
|
||||
# the task was canceled more than stale_timeout seconds ago
|
||||
if now > stale_date:
|
||||
clean_ids.append(task_id)
|
||||
needs_save = False
|
||||
for task_id in clean_ids:
|
||||
self.remove(task_id)
|
||||
self.cancel(task_id)
|
||||
del self.tasks[task_id]
|
||||
if task_id in self.to_save:
|
||||
del self.to_save[task_id]
|
||||
needs_save = True
|
||||
if needs_save:
|
||||
self.save()
|
||||
return True
|
||||
|
||||
def save(self):
|
||||
|
|
@ -351,9 +388,8 @@ class TaskHandler:
|
|||
delta = timedelta(seconds=timedelay)
|
||||
comp_time = now + delta
|
||||
# get an open task id
|
||||
used_ids = list(self.tasks.keys())
|
||||
task_id = 1
|
||||
while task_id in used_ids:
|
||||
while task_id in self.tasks:
|
||||
task_id += 1
|
||||
|
||||
# record the task to the tasks dictionary
|
||||
|
|
@ -395,19 +431,13 @@ class TaskHandler:
|
|||
self.tasks[task_id] = (comp_time, callback, args, kwargs, persistent, None)
|
||||
|
||||
# defer the task
|
||||
callback = self.do_task
|
||||
args = [task_id]
|
||||
kwargs = {}
|
||||
d = deferLater(self.clock, timedelay, callback, *args, **kwargs)
|
||||
d = deferLater(self.clock, timedelay, self.do_task, task_id)
|
||||
d.addErrback(handle_error)
|
||||
|
||||
# some tasks may complete before the deferred can be added
|
||||
if task_id in self.tasks:
|
||||
task = self.tasks.get(task_id)
|
||||
task = list(task)
|
||||
task[4] = persistent
|
||||
task[5] = d
|
||||
self.tasks[task_id] = task
|
||||
comp_time, cb, args, kwargs, _, _ = self.tasks[task_id]
|
||||
self.tasks[task_id] = (comp_time, cb, args, kwargs, persistent, d)
|
||||
else: # the task already completed
|
||||
return False
|
||||
if self.stale_timeout > 0:
|
||||
|
|
@ -426,10 +456,7 @@ class TaskHandler:
|
|||
bool: True the task exists False if it does not.
|
||||
|
||||
"""
|
||||
if task_id in self.tasks:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return task_id in self.tasks
|
||||
|
||||
def active(self, task_id):
|
||||
"""
|
||||
|
|
@ -444,7 +471,6 @@ class TaskHandler:
|
|||
|
||||
"""
|
||||
if task_id in self.tasks:
|
||||
# if the task has not been run, cancel it
|
||||
deferred = self.get_deferred(task_id)
|
||||
return not (deferred is not None and deferred.called)
|
||||
else:
|
||||
|
|
@ -464,7 +490,6 @@ class TaskHandler:
|
|||
|
||||
"""
|
||||
if task_id in self.tasks:
|
||||
# if the task has not been run, cancel it
|
||||
d = self.get_deferred(task_id)
|
||||
if d is not None: # it is remotely possible for a task to not have a deferred
|
||||
if d.called:
|
||||
|
|
@ -489,19 +514,14 @@ class TaskHandler:
|
|||
bool: True if the removal completed successfully.
|
||||
|
||||
"""
|
||||
d = None
|
||||
# delete the task from the tasks dictionary
|
||||
if task_id in self.tasks:
|
||||
# if the task has not been run, cancel it
|
||||
self.cancel(task_id)
|
||||
del self.tasks[task_id] # delete the task from the tasks dictionary
|
||||
del self.tasks[task_id]
|
||||
# remove the task from the persistent dictionary and ServerConfig
|
||||
if task_id in self.to_save:
|
||||
del self.to_save[task_id]
|
||||
self.save() # remove from ServerConfig.objects
|
||||
# delete the instance of the deferred
|
||||
if d:
|
||||
del d
|
||||
self.save()
|
||||
return True
|
||||
|
||||
def clear(self, save=True, cancel=True):
|
||||
|
|
|
|||
|
|
@ -6,8 +6,6 @@ Unit tests for the scripts package
|
|||
from collections import defaultdict
|
||||
from unittest import TestCase, mock
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from evennia import DefaultScript
|
||||
from evennia.objects.objects import DefaultObject
|
||||
from evennia.scripts.manager import ScriptDBManager
|
||||
|
|
@ -15,10 +13,10 @@ from evennia.scripts.models import ObjectDoesNotExist, ScriptDB
|
|||
from evennia.scripts.monitorhandler import MonitorHandler
|
||||
from evennia.scripts.ondemandhandler import OnDemandHandler, OnDemandTask
|
||||
from evennia.scripts.scripts import DoNothing, ExtendedLoopingCall
|
||||
from evennia.scripts.taskhandler import TASK_HANDLER
|
||||
from evennia.scripts.tickerhandler import TickerHandler
|
||||
from evennia.typeclasses.attributes import AttributeProperty
|
||||
from evennia.utils.create import create_script
|
||||
from evennia.utils.dbserialize import dbserialize
|
||||
from evennia.utils.test_resources import BaseEvenniaTest, EvenniaTest
|
||||
|
||||
|
||||
|
|
@ -382,6 +380,71 @@ class TestMonitorHandler(TestCase):
|
|||
self.assertEqual(self.handler.monitors[index][name], {})
|
||||
|
||||
|
||||
class TestTaskHandlerTask(TestCase):
|
||||
"""Test that TaskHandlerTask correctly handles stale references when task IDs are reused."""
|
||||
|
||||
def setUp(self):
|
||||
from twisted.internet import task as twisted_task
|
||||
|
||||
TASK_HANDLER.clock = twisted_task.Clock()
|
||||
TASK_HANDLER.clear()
|
||||
|
||||
def tearDown(self):
|
||||
TASK_HANDLER.clear()
|
||||
|
||||
def test_stale_reference_after_id_reuse(self):
|
||||
"""A stale TaskHandlerTask must not operate on a new task that reused its ID."""
|
||||
callback1 = mock.Mock(return_value="result1")
|
||||
callback2 = mock.Mock(return_value="result2")
|
||||
|
||||
# Create first task (gets ID 1)
|
||||
task1 = TASK_HANDLER.add(5, callback1)
|
||||
task1_id = task1.get_id()
|
||||
|
||||
# Complete and remove the first task so its ID is freed
|
||||
TASK_HANDLER.clock.advance(5)
|
||||
|
||||
# Create second task - should reuse ID 1
|
||||
task2 = TASK_HANDLER.add(5, callback2)
|
||||
self.assertEqual(task2.get_id(), task1_id)
|
||||
|
||||
# The stale reference (task1) must not affect the new task (task2)
|
||||
self.assertFalse(task1.exists())
|
||||
self.assertFalse(task1.active())
|
||||
self.assertFalse(task1.cancel())
|
||||
self.assertFalse(task1.remove())
|
||||
self.assertFalse(task1.do_task())
|
||||
self.assertFalse(task1.call())
|
||||
self.assertIsNone(task1.get_deferred())
|
||||
|
||||
# The new task must still be intact
|
||||
self.assertTrue(task2.exists())
|
||||
self.assertTrue(task2.active())
|
||||
|
||||
def test_valid_reference_works_normally(self):
|
||||
"""A valid TaskHandlerTask should work as expected."""
|
||||
callback = mock.Mock(return_value="result")
|
||||
task = TASK_HANDLER.add(5, callback)
|
||||
|
||||
self.assertTrue(task.exists())
|
||||
self.assertTrue(task.active())
|
||||
self.assertIsNotNone(task.get_deferred())
|
||||
|
||||
result = task.call()
|
||||
self.assertEqual(result, "result")
|
||||
callback.assert_called_once()
|
||||
|
||||
def test_is_valid_after_remove(self):
|
||||
"""After removing a task, its TaskHandlerTask reference should be invalid."""
|
||||
callback = mock.Mock()
|
||||
task = TASK_HANDLER.add(5, callback)
|
||||
|
||||
task.remove()
|
||||
self.assertFalse(task.exists())
|
||||
self.assertFalse(task.active())
|
||||
self.assertFalse(task.cancel())
|
||||
|
||||
|
||||
class TestOnDemandTask(EvenniaTest):
|
||||
"""
|
||||
Test the OnDemandTask class.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue