Start to build unit tests for grid

This commit is contained in:
Griatch 2021-06-23 18:51:20 +02:00
parent a5f799c40f
commit aaa67218d6
4 changed files with 151 additions and 46 deletions

View file

@ -441,7 +441,7 @@ class MapLink:
next_target = xygrid[end_x][end_y]
except KeyError:
# check if we have some special action up our sleeve
next_target = self.at_empty_target(end_direction, xygrid)
next_target = self.at_empty_target(start_direction, end_direction, xygrid)
if not next_target:
raise MapParserError(
@ -679,7 +679,7 @@ class TeleporterMapLink(MapLink):
super().__init__(*args, **kwargs)
self.paired_teleporter = None
def at_empty_target(self, start_direction, xygrid):
def at_empty_target(self, start_direction, end_direction, xygrid):
"""
Called during traversal, when finding an unknown direction out of the link (same as
targeting a link at an empty spot on the grid). This will also search for
@ -807,7 +807,12 @@ class MapTransitionLink(TeleporterMapLink):
"""
if not self.paired_map_link:
grid = self.xymap.grid.grid
try:
grid = self.xymap.grid.grid
except AttributeError:
raise MapParserError(f"requires this map being set up within an XYZgrid. No grid "
"was found (maybe it was not passed during XYMap initialization?",
self)
try:
target_map = grid[self.target_map]
except KeyError:

View file

@ -8,7 +8,8 @@ from time import time
from random import randint
from unittest import TestCase
from parameterized import parameterized
from . import xymap
from . import xymap, xyzgrid, map_legend
MAP1 = """
@ -313,6 +314,31 @@ MAP11_DISPLAY = r"""
#-#
""".strip()
MAP12a = r"""
+ 0 1
1 #-T
|
0 #-#
+ 0 1
"""
MAP12b = r"""
+ 0 1
1 #-#
|
0 T-#
+ 0 1
"""
class TestMap1(TestCase):
"""
@ -946,7 +972,7 @@ class TestMapStressTest(TestCase):
mapobj.parse()
t0 = time()
mapobj._calculate_path_matrix()
mapobj.calculate_path_matrix()
t1 = time()
# print(f"pathfinder matrix for grid {Xmax}x{Ymax}: {t1 - t0}s")
@ -979,7 +1005,7 @@ class TestMapStressTest(TestCase):
mapobj.parse()
t0 = time()
mapobj._calculate_path_matrix()
mapobj.calculate_path_matrix()
t1 = time()
# print(f"pathfinder matrix for grid {Xmax}x{Ymax}: {t1 - t0}s")
@ -1000,3 +1026,57 @@ class TestMapStressTest(TestCase):
f"slower than expected {max_time}s.")
# map transitions
class Map12aTransition(map_legend.MapTransitionLink):
symbol = "T"
target_map = "map12b"
class Map12bTransition(map_legend.MapTransitionLink):
symbol = "T"
target_map = "map12a"
class TestXYZGrid(TestCase):
"""
Test the XYZGrid class and transitions between maps.
"""
def setUp(self):
self.grid, err = xyzgrid.XYZGrid.create("testgrid")
self.map_data12a = {
"map": MAP12a,
"name": "map12a",
"legend": {"T": Map12aTransition}
}
self.map_data12b = {
"map": MAP12b,
"name": "map12b",
"legend": {"T": Map12bTransition}
}
self.grid.add_maps(self.map_data12a, self.map_data12b)
def tearDown(self):
self.grid.delete()
@parameterized.expand([
((1, 0), (1, 1), ('e', 'nw', 'e')),
((1, 1), (0, 0), ('w', 'se', 'w')),
])
def test_shortest_path(self, startcoord, endcoord, expected_directions):
"""
test shortest-path calculations throughout the grid.
"""
directions, _ = self.grid.get('map12a').get_shortest_path(startcoord, endcoord)
self.assertEqual(expected_directions, tuple(directions))
def test_transition(self):
"""
Test transition.
"""

View file

@ -467,7 +467,7 @@ class XYMap:
points, xmin, xmax, ymin, ymax = _scan_neighbors(center_node, [], dist=dist)
return list(set(points)), xmin, xmax, ymin, ymax
def _calculate_path_matrix(self):
def calculate_path_matrix(self):
"""
Solve the pathfinding problem using Dijkstra's algorithm. This will try to
load the solution from disk if possible.
@ -639,7 +639,7 @@ class XYMap:
f"{endnode}. They must both be MapNodes (not Links)")
if self.pathfinding_routes is None:
self._calculate_path_matrix()
self.calculate_path_matrix()
pathfinding_routes = self.pathfinding_routes
node_index_map = self.node_index_map

View file

@ -29,65 +29,97 @@ class XYZGrid(DefaultScript):
"""
def at_script_creation(self):
"""
What we store persistently is the module-paths to each map.
What we store persistently is data used to create each map (the legends, names etc)
"""
self.db.map_data = {}
@property
def grid(self):
if self.ndb.grid is None:
self.reload()
return self.ndb.grid
def get(self, mapname, default=None):
return self.grid.get(mapname, default)
def reload(self):
"""
Reload the grid. This is done on a server reload and is also necessary if adding a new map
since this may introduce new between-map traversals.
Reload and rebuild the grid. This is done on a server reload and is also necessary if adding
a new map since this may introduce new between-map traversals.
"""
# build the nodes of each map
for name, xymap in self.grid:
logger.log_info("[grid] (Re)loading grid ...")
grid = {}
nmaps = 0
# generate all Maps - this will also initialize their components
# and bake any pathfinding paths (or load from disk-cache)
for mapname, mapdata in self.db.map_data.items():
logger.log_info(f"[grid] Loading map '{mapname}'...")
xymap = XYMap(dict(mapdata), name=mapname, grid=self)
xymap.parse_first_pass()
# link everything together
for name, xymap in self.grid:
xymap.parse_second_pass()
grid[mapname] = xymap
nmaps += 1
def add_map(self, mapdata, new=True):
# link maps together across grid
logger.log_info("[grid] Link {nmaps} maps (may be slow first time a map has changed) ...")
for name, xymap in grid.items():
xymap.parse_second_pass()
xymap.calculate_path_matrix()
# store
self.ndb.grid = grid
logger.log_info(f"[grid] Loaded and linked {nmaps} map(s).")
def at_init(self):
"""
Add new map to the grid.
Called when the script loads into memory (on creation or after a reload). This will load all
map data into memory.
"""
self.reload()
def add_maps(self, *mapdatas):
"""
Add map or maps to the grid.
Args:
mapdata (dict): A structure `{"map": <mapstr>, "legend": <legenddict>,
"name": <name>, "prototypes": <dict-of-dicts>}`. The `prototypes are
*mapdatas (dict): Each argument is a dict structure
`{"map": <mapstr>, "legend": <legenddict>, "name": <name>,
"prototypes": <dict-of-dicts>}`. The `prototypes are
coordinate-specific overrides for nodes/links on the map, keyed with their
(X,Y) coordinate (use .5 for link-positions between nodes).
new (bool, optional): If the data should be resaved.
(X,Y) coordinate.
Raises:
RuntimeError: If mapdata is malformed.
Notes:
This will assume that all added maps produce a complete set (that is, they are correctly
and completely linked together with each other and/or with existing maps). So
this will automatically trigger `.reload()` to rebuild the grid.
After this, you need to run `.sync_to_grid` to make the new map actually
available in-game.
"""
name = mapdata.get('name')
if not name:
raise RuntimeError("XYZGrid.add_map data must contain 'name'.")
for mapdata in mapdatas:
name = mapdata.get('name')
if not name:
raise RuntimeError("XYZGrid.add_map data must contain 'name'.")
# this will raise MapErrors if there are issues with the map
self.grid[name] = XYMap(mapdata, name=name, grid=self)
if new:
self.db.map_data[name] = mapdata
def remove_map(self, zcoord, remove_objects=False):
def remove_map(self, mapname, remove_objects=False):
"""
Remove a map from the grid.
Args:
name (str): The map to remove.
mapname (str): The map to remove.
remove_objects (bool, optional): If the synced database objects (rooms/exits) should
be removed alongside this map.
"""
if zcoord in self.grid:
if mapname in self.db.map_data:
self.db.map_data.pop(zcoord)
self.grid.pop(zcoord)
self.reload()
if remove_objects:
pass
@ -116,7 +148,7 @@ class XYZGrid(DefaultScript):
if z is None:
xymaps = self.grid
elif z in self.grid:
elif z in self.ndb.grid:
xymaps = [self.grid[z]]
else:
raise RuntimeError(f"The 'z' coordinate/name '{z}' is not found on the grid.")
@ -132,15 +164,3 @@ class XYZGrid(DefaultScript):
for node in synced:
node.sync_links_to_grid()
def at_init(self):
"""
Called when the script loads into memory after a reload. This will load all map data into
memory.
"""
nmaps = 0
for mapname, mapdata in self.db.map_data:
self.add_map(mapdata, new=False)
nmaps += 1
self.reload()
logger.log_info(f"Loaded {nmaps} map(s) onto the grid.")