diff options
Diffstat (limited to 'archivist/peewee_ext.py')
-rw-r--r-- | archivist/peewee_ext.py | 125 |
1 files changed, 125 insertions, 0 deletions
diff --git a/archivist/peewee_ext.py b/archivist/peewee_ext.py new file mode 100644 index 0000000..9fda66e --- /dev/null +++ b/archivist/peewee_ext.py @@ -0,0 +1,125 @@ +from playhouse.sqlite_ext import VirtualModel, VirtualIntegerField, VirtualCharField +from peewee import Field + +class EnumField(Field): + db_field = 'enum' + + def __init__(self, enum_class, *args, **kwargs): + super().__init__(*args, **kwargs) + self.enum_class = enum_class + + def _enum_value(self, value): + if isinstance(value, str): + try: + return self.enum_class[value.upper()] + except KeyError: + pass + + try: + return self.enum_class(int(value)) + except ValueError: + raise ValueError("%r is not a valid %s" % (value, self.enum_class.__name__)) + + def db_value(self, value): + if value is None: + return value + + if isinstance(value, self.enum_class): + return value.value + + # force check of enum value + return self._enum_value(value).value + + def python_value(self, value): + return value if value is None else self._enum_value(value) + +def ClosureTable(model_class, referencing_class = None, foreign_key=None, id_column = None): + """Model factory for the transitive closure extension.""" + if referencing_class is None: + referencing_class = model_class + + if foreign_key is None: + for field_obj in model_class._meta.rel.values(): + if field_obj.rel_model is model_class: + foreign_key = field_obj + break + else: + raise ValueError('Unable to find self-referential foreign key.') + + primary_key = model_class._meta.primary_key + + if id_column is None: + id_column = primary_key + + class BaseClosureTable(VirtualModel): + depth = VirtualIntegerField() + id = VirtualIntegerField() + idcolumn = VirtualCharField() + parentcolumn = VirtualCharField() + root = VirtualIntegerField() + tablename = VirtualCharField() + + class Meta: + extension_module = 'transitive_closure' + + @classmethod + def descendants(cls, node, depth=None, include_node=False): + query = (model_class + .select(model_class, cls.depth.alias('depth')) + .join(cls, on=(primary_key == cls.id)) + .where(cls.root == node) + .naive()) + if depth is not None: + query = query.where(cls.depth == depth) + elif not include_node: + query = query.where(cls.depth > 0) + return query + + @classmethod + def ancestors(cls, node, depth=None, include_node=False): + query = (model_class + .select(model_class, cls.depth.alias('depth')) + .join(cls, on=(primary_key == cls.root)) + .where(cls.id == node) + .naive()) + if depth: + query = query.where(cls.depth == depth) + elif not include_node: + query = query.where(cls.depth > 0) + return query + + @classmethod + def siblings(cls, node, include_node=False): + if referencing_class is model_class: + # self-join + fk_value = node._data.get(foreign_key.name) + query = model_class.select().where(foreign_key == fk_value) + else: + # siblings as given in reference_class + siblings = (referencing_class + .select(id_column) + .join(cls, on=(foreign_key == cls.root)) + .where((cls.id == node) & (cls.depth == 1))) + + # the according models + query = (model_class + .select() + .where(primary_key << siblings) + .naive()) + + if not include_node: + query = query.where(primary_key != node) + + return query + + class Meta: + database = referencing_class._meta.database + extension_options = { + 'tablename': referencing_class._meta.db_table, + 'idcolumn': id_column.db_column, + 'parentcolumn': foreign_key.db_column} + primary_key = False + + name = '%sClosure' % model_class.__name__ + return type(name, (BaseClosureTable,), {'Meta': Meta, '__module__': __name__}) + |