From 7bd8fe0a34898c38dac062aef3bb22ecdc96317d Mon Sep 17 00:00:00 2001 From: Griatch Date: Sat, 3 Feb 2018 19:42:24 +0100 Subject: [PATCH] get_by_tag manager can now query for multiple tag/category combinations --- evennia/typeclasses/managers.py | 80 ++++++++++++--------------------- evennia/typeclasses/tests.py | 44 ++++++++++++++++-- 2 files changed, 69 insertions(+), 55 deletions(-) diff --git a/evennia/typeclasses/managers.py b/evennia/typeclasses/managers.py index ce84798dec..67eb9e065b 100644 --- a/evennia/typeclasses/managers.py +++ b/evennia/typeclasses/managers.py @@ -221,29 +221,6 @@ class TypedObjectManager(idmapper.manager.SharedMemoryManager): """ return self.get_tag(key=key, category=category, obj=obj, tagtype="alias") -# @returns_typeclass_list -# def get_by_tag(self, key=None, category=None, tagtype=None): -# """ -# Return objects having tags with a given key or category or -# combination of the two. -# -# Args: -# key (str, optional): Tag key. Not case sensitive. -# category (str, optional): Tag category. Not case sensitive. -# tagtype (str or None, optional): 'type' of Tag, by default -# this is either `None` (a normal Tag), `alias` or -# `permission`. -# Returns: -# objects (list): Objects with matching tag. -# """ -# dbmodel = self.model.__dbclass__.__name__.lower() -# query = [("db_tags__db_tagtype", tagtype), ("db_tags__db_model", dbmodel)] -# if key: -# query.append(("db_tags__db_key", key.lower())) -# if category: -# query.append(("db_tags__db_category", category.lower())) -# return self.filter(**dict(query)) - def get_by_tag(self, key=None, category=None, tagtype=None): """ Return objects having tags with a given key or category or combination of the two. @@ -253,7 +230,8 @@ class TypedObjectManager(idmapper.manager.SharedMemoryManager): key (str or list, optional): Tag key or list of keys. Not case sensitive. category (str or list, optional): Tag category. Not case sensitive. If `key` is a list, a single category can either apply to all keys in that list or this - must be a list matching the `key` list element by element. + must be a list matching the `key` list element by element. If no `key` is given, + all objects with tags of this category are returned. tagtype (str, optional): 'type' of Tag, by default this is either `None` (a normal Tag), `alias` or `permission`. This always apply to all queried tags. @@ -266,39 +244,37 @@ class TypedObjectManager(idmapper.manager.SharedMemoryManager): than `key`. """ - keys = make_iter(key) - categories = make_iter(category) + if not (key or category): + return [] + + keys = make_iter(key) if key else [] + categories = make_iter(category) if category else [] n_keys = len(keys) n_categories = len(categories) dbmodel = self.model.__dbclass__.__name__.lower() - if n_keys > 1: - if n_categories == 1: - category = categories[0] - query = Q() - for key in keys: - query = query & \ - Q(db_tags__db_tagtype=tagtype.lower() if tagtype else tagtype, - db_tags__db_category=category.lower() if category else category, - db_tags__db_model=dbmodel, - db_tags__db_key=key.lower()) - print "Query:", query - else: - query = Q(db_tags__db_tagtype=tagtype.lower(), - db_tags__db_model=dbmodel) - for ikey, key in keys: - category = categories[ikey] - category = category.lower() if category else category - query = query & Q(db_tags__db_key=key.lower(), - db_tags__db_category=category) - return self.filter(query) + query = self.filter(db_tags__db_tagtype__iexact=tagtype, + db_tags__db_model__iexact=dbmodel).distinct() + + if n_keys > 0: + # keys and/or categories given + if n_categories == 0: + categories = [None for _ in range(n_keys)] + elif n_categories == 1 and n_keys > 1: + cat = categories[0] + categories = [cat for _ in range(n_keys)] + elif 1 < n_categories < n_keys: + raise IndexError("get_by_tag needs a single category or a list of categories " + "the same length as the list of tags.") + for ikey, key in enumerate(keys): + query = query.filter(db_tags__db_key__iexact=key, + db_tags__db_category__iexact=categories[ikey]) else: - query = [("db_tags__db_tagtype", tagtype), ("db_tags__db_model", dbmodel)] - if key: - query.append(("db_tags__db_key", keys[0].lower())) - if category: - query.append(("db_tags__db_category", categories[0].lower())) - return self.filter(**dict(query)) + # only one or more categories given + for category in categories: + query = query.filter(db_tags__db_category__iexact=category) + + return query def get_by_permission(self, key=None, category=None): """ diff --git a/evennia/typeclasses/tests.py b/evennia/typeclasses/tests.py index 0885c35d40..b4a2361aae 100644 --- a/evennia/typeclasses/tests.py +++ b/evennia/typeclasses/tests.py @@ -12,10 +12,48 @@ from evennia.utils.test_resources import EvenniaTest class TestTypedObjectManager(EvenniaTest): def _manager(self, methodname, *args, **kwargs): - return getattr(self.obj1.__class__.objects, methodname)(*args, **kwargs) + return list(getattr(self.obj1.__class__.objects, methodname)(*args, **kwargs)) def test_get_by_tag_no_category(self): self.obj1.tags.add("tag1") + self.obj1.tags.add("tag2") + self.obj1.tags.add("tag2c") self.obj2.tags.add("tag2") - self.obj2.tags.add("tag3") - self.assertEquals(list(self._manager("get_by_tag", "tag1")), [self.obj1l]) + self.obj2.tags.add("tag2a") + self.obj2.tags.add("tag2b") + self.obj2.tags.add("tag3 with spaces") + self.obj2.tags.add("tag4") + self.obj2.tags.add("tag2c") + self.assertEquals(self._manager("get_by_tag", "tag1"), [self.obj1]) + self.assertEquals(self._manager("get_by_tag", "tag2"), [self.obj1, self.obj2]) + self.assertEquals(self._manager("get_by_tag", "tag2a"), [self.obj2]) + self.assertEquals(self._manager("get_by_tag", "tag3 with spaces"), [self.obj2]) + self.assertEquals(self._manager("get_by_tag", ["tag2a", "tag2b"]), [self.obj2]) + self.assertEquals(self._manager("get_by_tag", ["tag2a", "tag1"]), []) + self.assertEquals(self._manager("get_by_tag", ["tag2a", "tag4", "tag2c"]), [self.obj2]) + + def test_get_by_tag_and_category(self): + self.obj1.tags.add("tag5", "category1") + self.obj1.tags.add("tag6", ) + self.obj1.tags.add("tag7", "category1") + self.obj1.tags.add("tag6", "category3") + self.obj1.tags.add("tag7", "category4") + self.obj2.tags.add("tag5", "category1") + self.obj2.tags.add("tag5", "category2") + self.obj2.tags.add("tag6", "category3") + self.obj2.tags.add("tag7", "category1") + self.obj2.tags.add("tag7", "category5") + self.assertEquals(self._manager("get_by_tag", "tag5", "category1"), [self.obj1, self.obj2]) + self.assertEquals(self._manager("get_by_tag", "tag6", "category1"), []) + self.assertEquals(self._manager("get_by_tag", "tag6", "category3"), [self.obj1, self.obj2]) + self.assertEquals(self._manager("get_by_tag", ["tag5", "tag6"], + ["category1", "category3"]), [self.obj1, self.obj2]) + self.assertEquals(self._manager("get_by_tag", ["tag5", "tag7"], + "category1"), [self.obj1, self.obj2]) + self.assertEquals(self._manager("get_by_tag", category="category1"), [self.obj1, self.obj2]) + self.assertEquals(self._manager("get_by_tag", category="category2"), [self.obj2]) + self.assertEquals(self._manager("get_by_tag", category=["category1", "category3"]), + [self.obj1, self.obj2]) + self.assertEquals(self._manager("get_by_tag", category=["category1", "category2"]), + [self.obj2]) + self.assertEquals(self._manager("get_by_tag", category=["category5", "category4"]), [])