diff --git a/src/typeclasses/managers.py b/src/typeclasses/managers.py index 2c3b0afb9f..5cfd34d623 100644 --- a/src/typeclasses/managers.py +++ b/src/typeclasses/managers.py @@ -8,7 +8,7 @@ from django.db import models from django.db.models import Q from django.contrib.contenttypes.models import ContentType from src.utils import idmapper -from src.utils.utils import make_iter +from src.utils.utils import make_iter, variable_from_module from src.utils.dbserialize import to_pickle __all__ = ("AttributeManager", "TypedObjectManager") @@ -281,13 +281,44 @@ class TypedObjectManager(idmapper.manager.SharedMemoryManager): return dbtotals @returns_typeclass_list - def typeclass_search(self, typeclass): + def typeclass_search(self, typeclass, include_children=False, include_parents=False): """ - Searches through all objects returning those which has a certain - typeclass. If location is set, limit search to objects in - that location. + Searches through all objects returning those which has a + certain typeclass. If location is set, limit search to objects + in that location. + + typeclass - a typeclass class or a python path to a typeclass + include_children - return objects with given typeclass and all + children inheriting from this typeclass. + include_parents - return objects with given typeclass and all + parents to this typeclass + The include_children/parents keywords are mutually exclusive. """ + if callable(typeclass): cls = typeclass.__class__ typeclass = "%s.%s" % (cls.__module__, cls.__name__) - return self.filter(db_typeclass_path__exact=typeclass) + elif not isinstance(typeclass, basestring) and hasattr(typeclass, "path"): + typeclass = typeclass.path + + # query objects of exact typeclass + query = Q(db_typeclass_path__exact=typeclass) + + if include_children: + # build requests for child typeclass objects + clsmodule, clsname = typeclass.rsplit(".", 1) + cls = variable_from_module(clsmodule, clsname) + subclasses = cls.__subclasses__() + if subclasses: + for child in (child for child in subclasses if hasattr(child, "path")): + query = query | Q(db_typeclass_path__exact=child.path) + elif include_parents: + # build requests for parent typeclass objects + clsmodule, clsname = typeclass.rsplit(".", 1) + cls = variable_from_module(clsmodule, clsname) + parents = cls.__mro__ + if parents: + for parent in (parent for parent in parents if hasattr(parent, "path")): + query = query | Q(db_typeclass_path__exact=parent.path) + # actually query the database + return self.filter(query)