From ef3cda14736df8b8a9f7ff4f022b9a8250713bd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20=27Necoro=27=20Neumann?= Date: Sun, 26 Feb 2017 18:43:25 +0100 Subject: Refined the EnumField --- archivist/closure.py | 92 ----------------------------------- archivist/model.py | 25 +--------- archivist/peewee_ext.py | 125 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 126 insertions(+), 116 deletions(-) delete mode 100644 archivist/closure.py create mode 100644 archivist/peewee_ext.py diff --git a/archivist/closure.py b/archivist/closure.py deleted file mode 100644 index 01fdc75..0000000 --- a/archivist/closure.py +++ /dev/null @@ -1,92 +0,0 @@ -from playhouse.sqlite_ext import VirtualModel, VirtualIntegerField, VirtualCharField - -def ClosureTable(model_class, referencing_class = None, foreign_key=None, id_column = None): - """Model factory for the transitive closure extension.""" - if referencing_class is None: - referencing_class = model_class - - if foreign_key is None: - for field_obj in model_class._meta.rel.values(): - if field_obj.rel_model is model_class: - foreign_key = field_obj - break - else: - raise ValueError('Unable to find self-referential foreign key.') - - primary_key = model_class._meta.primary_key - - if id_column is None: - id_column = primary_key - - class BaseClosureTable(VirtualModel): - depth = VirtualIntegerField() - id = VirtualIntegerField() - idcolumn = VirtualCharField() - parentcolumn = VirtualCharField() - root = VirtualIntegerField() - tablename = VirtualCharField() - - class Meta: - extension_module = 'transitive_closure' - - @classmethod - def descendants(cls, node, depth=None, include_node=False): - query = (model_class - .select(model_class, cls.depth.alias('depth')) - .join(cls, on=(primary_key == cls.id)) - .where(cls.root == node) - .naive()) - if depth is not None: - query = query.where(cls.depth == depth) - elif not include_node: - query = query.where(cls.depth > 0) - return query - - @classmethod - def ancestors(cls, node, depth=None, include_node=False): - query = (model_class - .select(model_class, cls.depth.alias('depth')) - .join(cls, on=(primary_key == cls.root)) - .where(cls.id == node) - .naive()) - if depth: - query = query.where(cls.depth == depth) - elif not include_node: - query = query.where(cls.depth > 0) - return query - - @classmethod - def siblings(cls, node, include_node=False): - if referencing_class is model_class: - # self-join - fk_value = node._data.get(foreign_key.name) - query = model_class.select().where(foreign_key == fk_value) - else: - # siblings as given in reference_class - siblings = (referencing_class - .select(id_column) - .join(cls, on=(foreign_key == cls.root)) - .where((cls.id == node) & (cls.depth == 1))) - - # the according models - query = (model_class - .select() - .where(primary_key << siblings) - .naive()) - - if not include_node: - query = query.where(primary_key != node) - - return query - - class Meta: - database = referencing_class._meta.database - extension_options = { - 'tablename': referencing_class._meta.db_table, - 'idcolumn': id_column.db_column, - 'parentcolumn': foreign_key.db_column} - primary_key = False - - name = '%sClosure' % model_class.__name__ - return type(name, (BaseClosureTable,), {'Meta': Meta, '__module__': __name__}) - diff --git a/archivist/model.py b/archivist/model.py index d1ef6c7..fd1bec3 100644 --- a/archivist/model.py +++ b/archivist/model.py @@ -8,7 +8,7 @@ from enum import Enum, unique from pkg_resources import resource_filename from .prefixes import query_pseudo_prefix -from .closure import ClosureTable +from .peewee_ext import ClosureTable, EnumField db = SqliteExtDatabase('test.db', pragmas=[('foreign_keys', 'ON')]) db.load_extension(resource_filename(__name__, 'sqlext/closure')) @@ -31,29 +31,6 @@ class BaseModel(Model): class Meta: database = db -class EnumField(Field): - db_field = 'enum' - - def __init__(self, enum_class, *args, **kwargs): - super().__init__(*args, **kwargs) - self.enum_class = enum_class - - def _enum_value(self, value): - return self.enum_class(int(value)) - - def db_value(self, value): - if value is None: - return value - - if isinstance(value, self.enum_class): - return value.value - - # force check of enum value - return self._enum_value(value).value - - def python_value(self, value): - return value if value is None else self._enum_value(value) - @table class Document(BaseModel): @unique diff --git a/archivist/peewee_ext.py b/archivist/peewee_ext.py new file mode 100644 index 0000000..9fda66e --- /dev/null +++ b/archivist/peewee_ext.py @@ -0,0 +1,125 @@ +from playhouse.sqlite_ext import VirtualModel, VirtualIntegerField, VirtualCharField +from peewee import Field + +class EnumField(Field): + db_field = 'enum' + + def __init__(self, enum_class, *args, **kwargs): + super().__init__(*args, **kwargs) + self.enum_class = enum_class + + def _enum_value(self, value): + if isinstance(value, str): + try: + return self.enum_class[value.upper()] + except KeyError: + pass + + try: + return self.enum_class(int(value)) + except ValueError: + raise ValueError("%r is not a valid %s" % (value, self.enum_class.__name__)) + + def db_value(self, value): + if value is None: + return value + + if isinstance(value, self.enum_class): + return value.value + + # force check of enum value + return self._enum_value(value).value + + def python_value(self, value): + return value if value is None else self._enum_value(value) + +def ClosureTable(model_class, referencing_class = None, foreign_key=None, id_column = None): + """Model factory for the transitive closure extension.""" + if referencing_class is None: + referencing_class = model_class + + if foreign_key is None: + for field_obj in model_class._meta.rel.values(): + if field_obj.rel_model is model_class: + foreign_key = field_obj + break + else: + raise ValueError('Unable to find self-referential foreign key.') + + primary_key = model_class._meta.primary_key + + if id_column is None: + id_column = primary_key + + class BaseClosureTable(VirtualModel): + depth = VirtualIntegerField() + id = VirtualIntegerField() + idcolumn = VirtualCharField() + parentcolumn = VirtualCharField() + root = VirtualIntegerField() + tablename = VirtualCharField() + + class Meta: + extension_module = 'transitive_closure' + + @classmethod + def descendants(cls, node, depth=None, include_node=False): + query = (model_class + .select(model_class, cls.depth.alias('depth')) + .join(cls, on=(primary_key == cls.id)) + .where(cls.root == node) + .naive()) + if depth is not None: + query = query.where(cls.depth == depth) + elif not include_node: + query = query.where(cls.depth > 0) + return query + + @classmethod + def ancestors(cls, node, depth=None, include_node=False): + query = (model_class + .select(model_class, cls.depth.alias('depth')) + .join(cls, on=(primary_key == cls.root)) + .where(cls.id == node) + .naive()) + if depth: + query = query.where(cls.depth == depth) + elif not include_node: + query = query.where(cls.depth > 0) + return query + + @classmethod + def siblings(cls, node, include_node=False): + if referencing_class is model_class: + # self-join + fk_value = node._data.get(foreign_key.name) + query = model_class.select().where(foreign_key == fk_value) + else: + # siblings as given in reference_class + siblings = (referencing_class + .select(id_column) + .join(cls, on=(foreign_key == cls.root)) + .where((cls.id == node) & (cls.depth == 1))) + + # the according models + query = (model_class + .select() + .where(primary_key << siblings) + .naive()) + + if not include_node: + query = query.where(primary_key != node) + + return query + + class Meta: + database = referencing_class._meta.database + extension_options = { + 'tablename': referencing_class._meta.db_table, + 'idcolumn': id_column.db_column, + 'parentcolumn': foreign_key.db_column} + primary_key = False + + name = '%sClosure' % model_class.__name__ + return type(name, (BaseClosureTable,), {'Meta': Meta, '__module__': __name__}) + -- cgit v1.2.3