1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
|
from playhouse.sqlite_ext import VirtualModel, VirtualIntegerField, VirtualCharField
from peewee import Field, OP, DJANGO_MAP, ForeignKeyField, ReverseRelationDescriptor, Expression, Query
from itertools import starmap
from functools import reduce
import operator as op
def sqlite_tuple_in(fields, values):
    """Emulate a row-value ``IN`` test for SQLite.

    SQLite does not support ``(foo, bar) IN ((1, 2), (3, 4))``, so this builds
    the equivalent ``(foo = 1 AND bar = 2) OR (foo = 3 AND bar = 4)`` by
    combining ``field == value`` comparisons with ``&`` and ``|``.

    :param fields: the columns/expressions to compare (any iterable; it is
        materialised once so it can be reused for every value tuple).
    :param values: iterable of value tuples, each holding exactly one value
        per field.
    :returns: the OR of the per-row AND expressions.
    :raises ValueError: if *fields* or *values* is empty, or if a value tuple
        does not have exactly one value per field (``zip`` would otherwise
        silently drop the surplus columns and widen the match).
    """
    fields = tuple(fields)
    if not fields:
        raise ValueError('sqlite_tuple_in() needs at least one field')

    def _row_matches(value_tuple):
        # AND together "field == value" for one candidate row.
        value_tuple = tuple(value_tuple)
        if len(value_tuple) != len(fields):
            raise ValueError(
                'value tuple %r has %d items, expected %d'
                % (value_tuple, len(value_tuple), len(fields)))
        return reduce(op.and_, starmap(op.eq, zip(fields, value_tuple)))

    row_conditions = [_row_matches(value_tuple) for value_tuple in values]
    if not row_conditions:
        raise ValueError('sqlite_tuple_in() needs at least one value tuple')
    return reduce(op.or_, row_conditions)
def convert_dict_to_node(self, qdict):
    """Translate Django-style keyword filters into peewee expressions.

    Replacement for ``Query.convert_dict_to_node`` (installed right after this
    function). Each ``key=value`` pair in *qdict* becomes one ``Expression``:

    * ``path__<op>=v`` where ``<op>`` is a key of ``DJANGO_MAP`` uses the
      operator ``DJANGO_MAP[<op>]``;
    * ``path=None`` (no explicit operator) uses ``OP.IS``;
    * anything else uses ``OP.EQ``.

    Double-underscore paths (e.g. ``parent__name``) walk through
    ``ForeignKeyField`` / ``ReverseRelationDescriptor`` attributes, and every
    relationship walked is appended to ``joins``. When the value is ``None``,
    a relationship that is the *last* piece of the path is not joined, so
    ``parent=None`` compares the foreign-key column itself. Intermediate
    relationships are still joined so later pieces resolve on the related model.

    :param qdict: mapping of lookup strings to comparison values.
    :returns: ``(accum, joins)`` -- the expressions and the relationship
        attributes to join, both in key-sorted order.
    """
    accum = []
    joins = []
    relationship = (ForeignKeyField, ReverseRelationDescriptor)
    for key, value in sorted(qdict.items()):
        curr = self.model_class
        if '__' in key and key.rsplit('__', 1)[1] in DJANGO_MAP:
            key, lookup = key.rsplit('__', 1)
            operation = DJANGO_MAP[lookup]
        elif value is None:
            operation = OP.IS
        else:
            operation = OP.EQ
        pieces = key.split('__')
        last_index = len(pieces) - 1
        for index, piece in enumerate(pieces):
            model_attr = getattr(curr, piece)
            # Only a trailing relationship compared against None stays
            # unjoined; skipping intermediate joins would make the next
            # getattr() look the piece up on the wrong model.
            skip_join = value is None and index == last_index
            if isinstance(model_attr, relationship) and not skip_join:
                curr = model_attr.rel_model
                joins.append(model_attr)
        accum.append(Expression(model_attr, operation, value))
    return accum, joins


# Monkey-patch peewee so every Query uses the None-aware translation above.
Query.convert_dict_to_node = convert_dict_to_node
class EnumField(Field):
    """Peewee field holding members of a Python ``Enum`` class.

    Members are written to the database as their ``.value`` and read back as
    enum members. Strings are first upper-cased and looked up as member names,
    then parsed with ``int()`` and looked up as member values.
    """

    db_field = 'enum'

    def __init__(self, enum_class, *args, **kwargs):
        """Create the field.

        :param enum_class: the ``Enum`` subclass whose members this field holds.
        Remaining positional/keyword arguments are forwarded to ``Field``.
        """
        super().__init__(*args, **kwargs)
        self.enum_class = enum_class

    def _enum_value(self, value):
        """Coerce *value* to a member of ``self.enum_class``.

        A ``str`` is first tried as a member name (after ``.upper()``). Any
        value is then passed through ``int()`` and looked up by member value.
        Note that ``int()`` truncates floats, so e.g. ``1.9`` maps to member 1.

        :raises ValueError: if *value* matches no member, including values
            ``int()`` cannot convert at all (which raise TypeError internally).
        """
        if isinstance(value, str):
            try:
                return self.enum_class[value.upper()]
            except KeyError:
                pass  # not a member name; fall back to numeric lookup
        try:
            return self.enum_class(int(value))
        except (TypeError, ValueError) as err:
            # int() raises TypeError for non-numeric objects (lists, dicts, ...);
            # report every unusable input uniformly as ValueError.
            raise ValueError("%r is not a valid %s" % (value, self.enum_class.__name__)) from err

    def db_value(self, value):
        """Convert a Python value to the stored ``.value`` (``None`` passes through).

        Non-member inputs are validated via ``_enum_value`` before conversion.
        """
        if value is None:
            return value
        if isinstance(value, self.enum_class):
            return value.value
        # force check of enum value
        return self._enum_value(value).value

    def python_value(self, value):
        """Convert a stored value back to an enum member (``None`` passes through)."""
        return value if value is None else self._enum_value(value)
def ClosureTable(model_class, referencing_class=None, foreign_key=None, id_column=None):
    """Model factory for the SQLite ``transitive_closure`` virtual-table extension.

    Builds a ``VirtualModel`` subclass named ``<model_class.__name__>Closure``
    whose extension options point at *referencing_class*'s table, using
    *id_column* as the node id and *foreign_key* as the parent link.

    :param model_class: the model whose rows are the tree nodes returned by
        ``descendants`` / ``ancestors`` / ``siblings``.
    :param referencing_class: model whose table holds the parent links;
        defaults to *model_class* (a self-referencing table).
    :param foreign_key: the parent-link field. If omitted, the first field in
        ``model_class._meta.rel`` whose ``rel_model`` is *model_class* is used.
    :param id_column: the node-id field of the link table; defaults to
        ``model_class._meta.primary_key``.
    :raises ValueError: if *foreign_key* is omitted and no self-referential
        foreign key is found on *model_class*.
    :returns: the generated closure-table model class.
    """
    if referencing_class is None:
        referencing_class = model_class
    if foreign_key is None:
        # NOTE(review): this searches model_class, not referencing_class;
        # confirm that is intended when a separate link table is supplied.
        for field_obj in model_class._meta.rel.values():
            if field_obj.rel_model is model_class:
                foreign_key = field_obj
                break
        else:
            raise ValueError('Unable to find self-referential foreign key.')
    primary_key = model_class._meta.primary_key
    if id_column is None:
        id_column = primary_key

    class BaseClosureTable(VirtualModel):
        # Columns of the transitive_closure virtual table.
        depth = VirtualIntegerField()
        id = VirtualIntegerField()
        idcolumn = VirtualCharField()
        parentcolumn = VirtualCharField()
        root = VirtualIntegerField()
        tablename = VirtualCharField()

        class Meta:
            extension_module = 'transitive_closure'

        @classmethod
        def descendants(cls, node, depth=None, include_node=False):
            """Select ``model_class`` rows reachable below *node*.

            Each row also selects the closure ``depth`` (aliased ``depth``).
            If *depth* is given, only rows at exactly that depth are returned.
            Otherwise depth-0 rows (the node itself) are excluded unless
            *include_node* is true.
            """
            query = (model_class
                     .select(model_class, cls.depth.alias('depth'))
                     .join(cls, on=(primary_key == cls.id))
                     .where(cls.root == node)
                     .naive())
            if depth is not None:
                query = query.where(cls.depth == depth)
            elif not include_node:
                query = query.where(cls.depth > 0)
            return query

        @classmethod
        def ancestors(cls, node, depth=None, include_node=False):
            """Select ``model_class`` rows whose closure contains *node*.

            Each row also selects the closure ``depth`` (aliased ``depth``).
            If *depth* is given, only rows at exactly that depth are returned.
            Otherwise depth-0 rows (the node itself) are excluded unless
            *include_node* is true.
            """
            query = (model_class
                     .select(model_class, cls.depth.alias('depth'))
                     .join(cls, on=(primary_key == cls.root))
                     .where(cls.id == node)
                     .naive())
            # `is not None` so that depth=0 is honoured, as in descendants().
            if depth is not None:
                query = query.where(cls.depth == depth)
            elif not include_node:
                query = query.where(cls.depth > 0)
            return query

        @classmethod
        def siblings(cls, node, include_node=False):
            """Select ``model_class`` rows sharing *node*'s parent.

            For a self-referencing table this compares the parent foreign key
            directly, reading it from ``node._data``, so *node* must be a model
            instance there. Otherwise it looks up ids in *referencing_class*
            whose parent link equals the depth-1 closure root of *node*.
            *node* itself is excluded unless *include_node* is true.
            """
            if referencing_class is model_class:
                # self-join
                fk_value = node._data.get(foreign_key.name)
                query = model_class.select().where(foreign_key == fk_value)
            else:
                # siblings as given in referencing_class
                sibling_ids = (referencing_class
                               .select(id_column)
                               .join(cls, on=(foreign_key == cls.root))
                               .where((cls.id == node) & (cls.depth == 1)))
                # the corresponding models
                query = (model_class
                         .select()
                         .where(primary_key << sibling_ids)
                         .naive())
            if not include_node:
                query = query.where(primary_key != node)
            return query

    class Meta:
        database = referencing_class._meta.database
        extension_options = {
            'tablename': referencing_class._meta.db_table,
            'idcolumn': id_column.db_column,
            'parentcolumn': foreign_key.db_column}
        primary_key = False

    name = '%sClosure' % model_class.__name__
    return type(name, (BaseClosureTable,), {'Meta': Meta, '__module__': __name__})
|