From b00b51565e4f3aefd6e86c1c9d9c46f70711887b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20=27Necoro=27=20Neumann?= Date: Sun, 26 Feb 2017 17:30:16 +0100 Subject: Introduce the closure --- archivist/closure.py | 92 ++++++++++++++++++++++++++++++++++++++++++++++++++++ archivist/model.py | 13 ++++++-- 2 files changed, 102 insertions(+), 3 deletions(-) create mode 100644 archivist/closure.py diff --git a/archivist/closure.py b/archivist/closure.py new file mode 100644 index 0000000..01fdc75 --- /dev/null +++ b/archivist/closure.py @@ -0,0 +1,92 @@ +from playhouse.sqlite_ext import VirtualModel, VirtualIntegerField, VirtualCharField + +def ClosureTable(model_class, referencing_class = None, foreign_key=None, id_column = None): + """Model factory for the transitive closure extension.""" + if referencing_class is None: + referencing_class = model_class + + if foreign_key is None: + for field_obj in model_class._meta.rel.values(): + if field_obj.rel_model is model_class: + foreign_key = field_obj + break + else: + raise ValueError('Unable to find self-referential foreign key.') + + primary_key = model_class._meta.primary_key + + if id_column is None: + id_column = primary_key + + class BaseClosureTable(VirtualModel): + depth = VirtualIntegerField() + id = VirtualIntegerField() + idcolumn = VirtualCharField() + parentcolumn = VirtualCharField() + root = VirtualIntegerField() + tablename = VirtualCharField() + + class Meta: + extension_module = 'transitive_closure' + + @classmethod + def descendants(cls, node, depth=None, include_node=False): + query = (model_class + .select(model_class, cls.depth.alias('depth')) + .join(cls, on=(primary_key == cls.id)) + .where(cls.root == node) + .naive()) + if depth is not None: + query = query.where(cls.depth == depth) + elif not include_node: + query = query.where(cls.depth > 0) + return query + + @classmethod + def ancestors(cls, node, depth=None, include_node=False): + query = (model_class + .select(model_class, cls.depth.alias('depth')) + .join(cls, on=(primary_key == cls.root)) + .where(cls.id == node) + .naive()) + if depth: + query = query.where(cls.depth == depth) + elif not include_node: + query = query.where(cls.depth > 0) + return query + + @classmethod + def siblings(cls, node, include_node=False): + if referencing_class is model_class: + # self-join + fk_value = node._data.get(foreign_key.name) + query = model_class.select().where(foreign_key == fk_value) + else: + # siblings as given in reference_class + siblings = (referencing_class + .select(id_column) + .join(cls, on=(foreign_key == cls.root)) + .where((cls.id == node) & (cls.depth == 1))) + + # the according models + query = (model_class + .select() + .where(primary_key << siblings) + .naive()) + + if not include_node: + query = query.where(primary_key != node) + + return query + + class Meta: + database = referencing_class._meta.database + extension_options = { + 'tablename': referencing_class._meta.db_table, + 'idcolumn': id_column.db_column, + 'parentcolumn': foreign_key.db_column} + primary_key = False + + name = '%sClosure' % model_class.__name__ + return type(name, (BaseClosureTable,), {'Meta': Meta, '__module__': __name__}) + diff --git a/archivist/model.py b/archivist/model.py index 79b0514..dd4c87e 100644 --- a/archivist/model.py +++ b/archivist/model.py @@ -1,12 +1,16 @@ from peewee import * from playhouse.fields import CompressedField from playhouse.hybrid import * +from playhouse.sqlite_ext import SqliteExtDatabase import datetime +from pkg_resources import resource_filename from .prefixes import query_pseudo_prefix +from .closure import ClosureTable -db = SqliteDatabase('test.db', pragmas=[('foreign_keys', 'ON')]) +db = SqliteExtDatabase('test.db', pragmas=[('foreign_keys', 'ON')]) +db.load_extension(resource_filename(__name__, 'sqlext/closure')) __tables__ = [] __all__ = ['create_tables', 'drop_tables'] @@ -101,11 +105,14 @@ class DocumentTag(BaseModel): @table class TagImplications(BaseModel): - tag = ForeignKeyField(Tag) - implies_tag = ForeignKeyField(Tag, related_name = 'implications') + tag = ForeignKeyField(Tag, related_name = 'implications') + implies_tag = ForeignKeyField(Tag, related_name = '_implied_by') class Meta: primary_key = CompositeKey('tag', 'implies_tag') def __repr__(self): return "<%s %d --> %d>" % (self.__class__.__name__, self.tag_id, self.implies_tag_id) + +TagClosure = ClosureTable(Tag, TagImplications, TagImplications.implies_tag, TagImplications.tag) +table(TagClosure) -- cgit v1.2.3