from playhouse.sqlite_ext import VirtualModel, VirtualIntegerField, VirtualCharField
from peewee import Field


class EnumField(Field):
    """Field that stores members of ``enum_class`` by their ``.value``.

    Values written to the database are converted to the member's value;
    values read back are converted into members of ``enum_class``.
    """

    db_field = 'enum'

    def __init__(self, enum_class, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Enum class used to coerce values in both directions.
        self.enum_class = enum_class

    def _enum_value(self, value):
        """Coerce *value* to a member of ``self.enum_class``.

        A string is first looked up as a member *name* after upper-casing it.
        Failing that (or for non-strings), *value* is passed through ``int()``
        and looked up as a member *value*, so value lookups assume an
        integer-valued enum.

        Raises:
            ValueError: if *value* matches neither a member name nor a member
                value (or cannot be converted with ``int()``).
        """
        if isinstance(value, str):
            try:
                return self.enum_class[value.upper()]
            except KeyError:
                pass  # not a member name; fall back to lookup by value
        try:
            return self.enum_class(int(value))
        except ValueError as exc:
            # Keep the original error as the cause for easier debugging.
            raise ValueError("%r is not a valid %s" % (value, self.enum_class.__name__)) from exc

    def db_value(self, value):
        """Return the raw value to store for *value* (``None`` passes through)."""
        if value is None:
            return value
        if isinstance(value, self.enum_class):
            return value.value
        # Not already a member: validate it by coercing, then store its value.
        return self._enum_value(value).value

    def python_value(self, value):
        """Convert a raw database value to an enum member (``None`` passes through)."""
        return value if value is None else self._enum_value(value)


def ClosureTable(model_class, referencing_class=None, foreign_key=None, id_column=None):
    """Model factory for the transitive closure extension.

    Builds a ``VirtualModel`` subclass whose ``Meta.extension_module`` is
    ``'transitive_closure'``, configured to walk the id/parent columns of
    ``referencing_class``'s table.

    Args:
        model_class: model whose rows are returned by the generated
            ``descendants``/``ancestors``/``siblings`` queries.
        referencing_class: model whose table holds the id/parent columns.
            Defaults to ``model_class`` (a self-referential hierarchy).
        foreign_key: field holding the parent reference. When omitted, the
            first field in ``model_class._meta.rel`` whose ``rel_model`` is
            ``model_class`` is used.
        id_column: field holding each row's own id. Defaults to
            ``model_class``'s primary key.

    Returns:
        A new model class named ``'<ModelName>Closure'``.

    Raises:
        ValueError: if ``foreign_key`` is omitted and ``model_class`` has no
            self-referential foreign key.

    NOTE(review): the ``foreign_key`` auto-detection and the ``id_column``
    default both look at ``model_class``, not ``referencing_class``. When
    ``referencing_class`` is a separate table, pass both explicitly.
    """
    if referencing_class is None:
        referencing_class = model_class

    if foreign_key is None:
        for field_obj in model_class._meta.rel.values():
            if field_obj.rel_model is model_class:
                foreign_key = field_obj
                break
        else:
            raise ValueError('Unable to find self-referential foreign key.')

    primary_key = model_class._meta.primary_key
    if id_column is None:
        id_column = primary_key

    class BaseClosureTable(VirtualModel):
        # Columns exposed by the transitive_closure virtual table.
        depth = VirtualIntegerField()
        id = VirtualIntegerField()
        idcolumn = VirtualCharField()
        parentcolumn = VirtualCharField()
        root = VirtualIntegerField()
        tablename = VirtualCharField()

        class Meta:
            extension_module = 'transitive_closure'

        @classmethod
        def descendants(cls, node, depth=None, include_node=False):
            """Return a query of ``model_class`` rows below *node*.

            Each row is annotated with its ``depth``. If *depth* is given,
            only rows at exactly that depth are returned and *include_node*
            is ignored. Otherwise the depth-0 row (*node* itself) is excluded
            unless *include_node* is true.
            """
            query = (model_class
                     .select(model_class, cls.depth.alias('depth'))
                     .join(cls, on=(primary_key == cls.id))
                     .where(cls.root == node)
                     .naive())
            if depth is not None:
                query = query.where(cls.depth == depth)
            elif not include_node:
                query = query.where(cls.depth > 0)
            return query

        @classmethod
        def ancestors(cls, node, depth=None, include_node=False):
            """Return a query of ``model_class`` rows above *node*.

            Each row is annotated with its ``depth``. If *depth* is given
            (including ``0``), only rows at exactly that depth are returned
            and *include_node* is ignored. Otherwise the depth-0 row (*node*
            itself) is excluded unless *include_node* is true.
            """
            query = (model_class
                     .select(model_class, cls.depth.alias('depth'))
                     .join(cls, on=(primary_key == cls.root))
                     .where(cls.id == node)
                     .naive())
            # ``is not None`` (matching ``descendants``) so that depth=0
            # selects the depth-0 row instead of being silently ignored.
            if depth is not None:
                query = query.where(cls.depth == depth)
            elif not include_node:
                query = query.where(cls.depth > 0)
            return query

        @classmethod
        def siblings(cls, node, include_node=False):
            """Return a query of ``model_class`` rows sharing *node*'s parent.

            *node* itself is excluded unless *include_node* is true.
            """
            if referencing_class is model_class:
                # Self-referential table: siblings share the same foreign-key
                # value. Reads ``node._data``, so *node* must be a model
                # instance here.
                fk_value = node._data.get(foreign_key.name)
                query = model_class.select().where(foreign_key == fk_value)
            else:
                # Separate referencing table: the depth-1 closure row of
                # *node* has the parent as its root. Collect the ids of every
                # referencing row pointing at that parent.
                sibling_ids = (referencing_class
                               .select(id_column)
                               .join(cls, on=(foreign_key == cls.root))
                               .where((cls.id == node) & (cls.depth == 1)))
                # ...and load the corresponding model rows.
                query = (model_class
                         .select()
                         .where(primary_key << sibling_ids)
                         .naive())
            if not include_node:
                query = query.where(primary_key != node)
            return query

    class Meta:
        database = referencing_class._meta.database
        # Tells the extension which table and columns hold the edges.
        extension_options = {
            'tablename': referencing_class._meta.db_table,
            'idcolumn': id_column.db_column,
            'parentcolumn': foreign_key.db_column}
        primary_key = False

    name = '%sClosure' % model_class.__name__
    return type(name, (BaseClosureTable,), {
        'Meta': Meta,
        '__module__': __name__})