summaryrefslogtreecommitdiff
path: root/archivist/closure.py
blob: 01fdc7589b822f70f800d8f197a04cc7f0fb3296 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from playhouse.sqlite_ext import VirtualModel, VirtualIntegerField, VirtualCharField

def ClosureTable(model_class, referencing_class=None, foreign_key=None,
                 id_column=None):
    """Model factory for the SQLite ``transitive_closure`` extension.

    Builds a ``VirtualModel`` subclass named ``<ModelName>Closure`` whose
    virtual table is served by the ``transitive_closure`` extension module
    and configured to walk the hierarchy stored in ``referencing_class``.

    :param model_class: The model whose rows are the nodes of the hierarchy.
    :param referencing_class: The model (table) that stores the parent
        relationship. Defaults to ``model_class`` (a self-referential tree).
    :param foreign_key: The field in ``referencing_class`` pointing at a
        node's parent. When omitted, the first foreign key found in
        ``model_class._meta.rel`` whose ``rel_model`` is ``model_class`` is
        used.
    :param id_column: The field in ``referencing_class`` identifying a node.
        Defaults to ``model_class``'s primary key.
    :returns: A new closure-table model class.
    :raises ValueError: If ``foreign_key`` is omitted and ``model_class`` has
        no self-referential foreign key.
    """
    if referencing_class is None:
        referencing_class = model_class

    if foreign_key is None:
        for field_obj in model_class._meta.rel.values():
            if field_obj.rel_model is model_class:
                foreign_key = field_obj
                break
        else:
            raise ValueError('Unable to find self-referential foreign key.')

    primary_key = model_class._meta.primary_key

    if id_column is None:
        id_column = primary_key

    class BaseClosureTable(VirtualModel):
        # Columns exposed by the transitive_closure virtual table.
        depth = VirtualIntegerField()
        id = VirtualIntegerField()
        idcolumn = VirtualCharField()
        parentcolumn = VirtualCharField()
        root = VirtualIntegerField()
        tablename = VirtualCharField()

        class Meta:
            extension_module = 'transitive_closure'

        @classmethod
        def descendants(cls, node, depth=None, include_node=False):
            """Return a query over the nodes below ``node``.

            Rows are ``model_class`` instances carrying an extra ``depth``
            attribute (selected from the closure table via ``.naive()``).

            :param node: The node (model instance or primary-key value)
                whose subtree is queried.
            :param depth: If not ``None``, only return rows at exactly this
                depth; ``0`` selects ``node`` itself.
            :param include_node: Include ``node`` (depth 0) in the results.
                Ignored when ``depth`` is given.
            """
            query = (model_class
                     .select(model_class, cls.depth.alias('depth'))
                     .join(cls, on=(primary_key == cls.id))
                     .where(cls.root == node)
                     .naive())
            if depth is not None:
                query = query.where(cls.depth == depth)
            elif not include_node:
                query = query.where(cls.depth > 0)
            return query

        @classmethod
        def ancestors(cls, node, depth=None, include_node=False):
            """Return a query over the nodes above ``node``.

            Rows are ``model_class`` instances carrying an extra ``depth``
            attribute (selected from the closure table via ``.naive()``).

            :param node: The node (model instance or primary-key value)
                whose ancestry is queried.
            :param depth: If not ``None``, only return rows at exactly this
                depth; ``0`` selects ``node`` itself.
            :param include_node: Include ``node`` (depth 0) in the results.
                Ignored when ``depth`` is given.
            """
            query = (model_class
                     .select(model_class, cls.depth.alias('depth'))
                     .join(cls, on=(primary_key == cls.root))
                     .where(cls.id == node)
                     .naive())
            # Compare with None explicitly (as descendants() does) so that
            # depth=0 is honored instead of being treated as "no filter".
            if depth is not None:
                query = query.where(cls.depth == depth)
            elif not include_node:
                query = query.where(cls.depth > 0)
            return query

        @classmethod
        def siblings(cls, node, include_node=False):
            """Return a query over the nodes that share a parent with ``node``.

            :param node: The node whose siblings are queried. For a
                self-referential tree this must be a model instance, because
                its parent value is read from ``node._data``.
            :param include_node: Include ``node`` itself in the results.
            """
            if referencing_class is model_class:
                # Self-referential tree: siblings are the rows whose parent
                # foreign key equals this node's parent value.
                fk_value = node._data.get(foreign_key.name)
                query = model_class.select().where(foreign_key == fk_value)
            else:
                # Separate relationship table: the closure row at depth 1
                # above ``node`` gives its parent; every row of
                # referencing_class pointing at that parent is a sibling.
                sibling_ids = (referencing_class
                               .select(id_column)
                               .join(cls, on=(foreign_key == cls.root))
                               .where((cls.id == node) & (cls.depth == 1)))

                # Load the matching model_class rows.
                query = (model_class
                         .select()
                         .where(primary_key << sibling_ids)
                         .naive())

            if not include_node:
                query = query.where(primary_key != node)

            return query

    class Meta:
        database = referencing_class._meta.database
        # Tell the virtual table which table and columns hold the hierarchy.
        extension_options = {
            'tablename': referencing_class._meta.db_table,
            'idcolumn': id_column.db_column,
            'parentcolumn': foreign_key.db_column}
        primary_key = False

    name = '%sClosure' % model_class.__name__
    return type(name, (BaseClosureTable,),
                {'Meta': Meta, '__module__': __name__})