from playhouse.sqlite_ext import VirtualModel, VirtualIntegerField, VirtualCharField


def ClosureTable(model_class, referencing_class=None, foreign_key=None,
                 id_column=None):
    """Model factory for the transitive closure extension.

    Builds a ``VirtualModel`` subclass backed by SQLite's
    ``transitive_closure`` extension module and configured (via
    ``extension_options``) to walk the parent links stored in
    ``referencing_class``'s table.

    :param model_class: model whose rows are the nodes being queried.
    :param referencing_class: model whose table holds the parent links;
        defaults to ``model_class`` (a self-referential hierarchy).
    :param foreign_key: field holding the parent reference. When omitted,
        the first field encountered in ``model_class._meta.rel`` whose
        ``rel_model`` is ``model_class`` itself is used.
    :param id_column: field holding the node id in ``referencing_class``;
        defaults to ``model_class``'s primary key.
    :returns: a new model class named ``'<ModelName>Closure'``.
    :raises ValueError: if ``foreign_key`` is omitted and no
        self-referential foreign key exists on ``model_class``.
    """
    if referencing_class is None:
        referencing_class = model_class

    if foreign_key is None:
        # Auto-detect the parent link: a foreign key on model_class that
        # points back at model_class itself.
        for field_obj in model_class._meta.rel.values():
            if field_obj.rel_model is model_class:
                foreign_key = field_obj
                break
        else:
            raise ValueError('Unable to find self-referential foreign key.')

    primary_key = model_class._meta.primary_key

    if id_column is None:
        id_column = primary_key

    class BaseClosureTable(VirtualModel):
        # Columns of the transitive_closure virtual table.
        depth = VirtualIntegerField()
        id = VirtualIntegerField()
        idcolumn = VirtualCharField()
        parentcolumn = VirtualCharField()
        root = VirtualIntegerField()
        tablename = VirtualCharField()

        class Meta:
            extension_module = 'transitive_closure'

        @classmethod
        def descendants(cls, node, depth=None, include_node=False):
            """Return ``model_class`` rows whose closure ``root`` is *node*.

            Each row is selected together with the closure ``depth``
            (aliased ``'depth'``).

            :param node: the node to start from.
            :param depth: if not ``None``, keep only rows at exactly this
                depth (``include_node`` is then ignored).
            :param include_node: when ``depth`` is ``None``, also keep the
                depth-0 row (the node itself); otherwise only depth > 0.
            """
            query = (model_class
                     .select(model_class, cls.depth.alias('depth'))
                     .join(cls, on=(primary_key == cls.id))
                     .where(cls.root == node)
                     .naive())
            if depth is not None:
                query = query.where(cls.depth == depth)
            elif not include_node:
                query = query.where(cls.depth > 0)
            return query

        @classmethod
        def ancestors(cls, node, depth=None, include_node=False):
            """Return ``model_class`` rows that are closure roots of *node*.

            Each row is selected together with the closure ``depth``
            (aliased ``'depth'``).

            :param node: the node whose ancestors are wanted.
            :param depth: if not ``None``, keep only rows at exactly this
                depth (``include_node`` is then ignored).
            :param include_node: when ``depth`` is ``None``, also keep the
                depth-0 row (the node itself); otherwise only depth > 0.
            """
            query = (model_class
                     .select(model_class, cls.depth.alias('depth'))
                     .join(cls, on=(primary_key == cls.root))
                     .where(cls.id == node)
                     .naive())
            # Compare against None (not truthiness) so that depth=0 is
            # honoured, matching descendants().
            if depth is not None:
                query = query.where(cls.depth == depth)
            elif not include_node:
                query = query.where(cls.depth > 0)
            return query

        @classmethod
        def siblings(cls, node, include_node=False):
            """Return ``model_class`` rows sharing *node*'s parent.

            :param node: the reference node. In the self-referential case
                its parent id is read from ``node._data``, so a model
                instance is expected there (NOTE(review): a bare id will
                not work in that branch).
            :param include_node: if false, exclude *node* itself.
            """
            if referencing_class is model_class:
                # Self-join: siblings are rows whose parent FK equals ours.
                fk_value = node._data.get(foreign_key.name)
                query = model_class.select().where(foreign_key == fk_value)
            else:
                # Sibling ids as recorded in referencing_class: everything
                # one level below node's parent.
                siblings = (referencing_class
                            .select(id_column)
                            .join(cls, on=(foreign_key == cls.root))
                            .where((cls.id == node) & (cls.depth == 1)))
                # The corresponding model_class rows.
                query = (model_class
                         .select()
                         .where(primary_key << siblings)
                         .naive())

            if not include_node:
                query = query.where(primary_key != node)

            return query

    class Meta:
        database = referencing_class._meta.database
        # Options passed to the virtual-table constructor: which table and
        # columns hold the id/parent relationship to be walked.
        extension_options = {
            'tablename': referencing_class._meta.db_table,
            'idcolumn': id_column.db_column,
            'parentcolumn': foreign_key.db_column}
        primary_key = False

    name = '%sClosure' % model_class.__name__
    return type(name, (BaseClosureTable,), {
        'Meta': Meta,
        '__module__': __name__})