summaryrefslogtreecommitdiff
path: root/archivist/peewee_ext.py
diff options
context:
space:
mode:
Diffstat (limited to 'archivist/peewee_ext.py')
-rw-r--r--archivist/peewee_ext.py125
1 files changed, 125 insertions, 0 deletions
diff --git a/archivist/peewee_ext.py b/archivist/peewee_ext.py
new file mode 100644
index 0000000..9fda66e
--- /dev/null
+++ b/archivist/peewee_ext.py
@@ -0,0 +1,125 @@
+from playhouse.sqlite_ext import VirtualModel, VirtualIntegerField, VirtualCharField
+from peewee import Field
+
+class EnumField(Field):
+ db_field = 'enum'
+
+ def __init__(self, enum_class, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.enum_class = enum_class
+
+ def _enum_value(self, value):
+ if isinstance(value, str):
+ try:
+ return self.enum_class[value.upper()]
+ except KeyError:
+ pass
+
+ try:
+ return self.enum_class(int(value))
+ except ValueError:
+ raise ValueError("%r is not a valid %s" % (value, self.enum_class.__name__))
+
+ def db_value(self, value):
+ if value is None:
+ return value
+
+ if isinstance(value, self.enum_class):
+ return value.value
+
+ # force check of enum value
+ return self._enum_value(value).value
+
+ def python_value(self, value):
+ return value if value is None else self._enum_value(value)
+
+def ClosureTable(model_class, referencing_class = None, foreign_key=None, id_column = None):
+ """Model factory for the transitive closure extension."""
+ if referencing_class is None:
+ referencing_class = model_class
+
+ if foreign_key is None:
+ for field_obj in model_class._meta.rel.values():
+ if field_obj.rel_model is model_class:
+ foreign_key = field_obj
+ break
+ else:
+ raise ValueError('Unable to find self-referential foreign key.')
+
+ primary_key = model_class._meta.primary_key
+
+ if id_column is None:
+ id_column = primary_key
+
+ class BaseClosureTable(VirtualModel):
+ depth = VirtualIntegerField()
+ id = VirtualIntegerField()
+ idcolumn = VirtualCharField()
+ parentcolumn = VirtualCharField()
+ root = VirtualIntegerField()
+ tablename = VirtualCharField()
+
+ class Meta:
+ extension_module = 'transitive_closure'
+
+ @classmethod
+ def descendants(cls, node, depth=None, include_node=False):
+ query = (model_class
+ .select(model_class, cls.depth.alias('depth'))
+ .join(cls, on=(primary_key == cls.id))
+ .where(cls.root == node)
+ .naive())
+ if depth is not None:
+ query = query.where(cls.depth == depth)
+ elif not include_node:
+ query = query.where(cls.depth > 0)
+ return query
+
+ @classmethod
+ def ancestors(cls, node, depth=None, include_node=False):
+ query = (model_class
+ .select(model_class, cls.depth.alias('depth'))
+ .join(cls, on=(primary_key == cls.root))
+ .where(cls.id == node)
+ .naive())
+ if depth:
+ query = query.where(cls.depth == depth)
+ elif not include_node:
+ query = query.where(cls.depth > 0)
+ return query
+
+ @classmethod
+ def siblings(cls, node, include_node=False):
+ if referencing_class is model_class:
+ # self-join
+ fk_value = node._data.get(foreign_key.name)
+ query = model_class.select().where(foreign_key == fk_value)
+ else:
+ # siblings as given in reference_class
+ siblings = (referencing_class
+ .select(id_column)
+ .join(cls, on=(foreign_key == cls.root))
+ .where((cls.id == node) & (cls.depth == 1)))
+
+ # the according models
+ query = (model_class
+ .select()
+ .where(primary_key << siblings)
+ .naive())
+
+ if not include_node:
+ query = query.where(primary_key != node)
+
+ return query
+
+ class Meta:
+ database = referencing_class._meta.database
+ extension_options = {
+ 'tablename': referencing_class._meta.db_table,
+ 'idcolumn': id_column.db_column,
+ 'parentcolumn': foreign_key.db_column}
+ primary_key = False
+
+ name = '%sClosure' % model_class.__name__
+ return type(name, (BaseClosureTable,), {'Meta': Meta, '__module__': __name__})
+