summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRené 'Necoro' Neumann <necoro@necoro.net>2017-02-26 18:43:25 +0100
committerRené 'Necoro' Neumann <necoro@necoro.net>2017-02-26 18:43:25 +0100
commitef3cda14736df8b8a9f7ff4f022b9a8250713bd9 (patch)
treec63fe2b4d0146219c3e1d0d5d51492ffda8c5c3e
parent247a64165b014960be2d26f9d7a16559b36ac8bf (diff)
downloadarchivist-ef3cda14736df8b8a9f7ff4f022b9a8250713bd9.tar.gz
archivist-ef3cda14736df8b8a9f7ff4f022b9a8250713bd9.tar.bz2
archivist-ef3cda14736df8b8a9f7ff4f022b9a8250713bd9.zip
Refined the EnumField
-rw-r--r--archivist/model.py25
-rw-r--r--archivist/peewee_ext.py (renamed from archivist/closure.py)33
2 files changed, 34 insertions, 24 deletions
diff --git a/archivist/model.py b/archivist/model.py
index d1ef6c7..fd1bec3 100644
--- a/archivist/model.py
+++ b/archivist/model.py
@@ -8,7 +8,7 @@ from enum import Enum, unique
from pkg_resources import resource_filename
from .prefixes import query_pseudo_prefix
-from .closure import ClosureTable
+from .peewee_ext import ClosureTable, EnumField
db = SqliteExtDatabase('test.db', pragmas=[('foreign_keys', 'ON')])
db.load_extension(resource_filename(__name__, 'sqlext/closure'))
@@ -31,29 +31,6 @@ class BaseModel(Model):
class Meta:
database = db
-class EnumField(Field):
- db_field = 'enum'
-
- def __init__(self, enum_class, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.enum_class = enum_class
-
- def _enum_value(self, value):
- return self.enum_class(int(value))
-
- def db_value(self, value):
- if value is None:
- return value
-
- if isinstance(value, self.enum_class):
- return value.value
-
- # force check of enum value
- return self._enum_value(value).value
-
- def python_value(self, value):
- return value if value is None else self._enum_value(value)
-
@table
class Document(BaseModel):
@unique
diff --git a/archivist/closure.py b/archivist/peewee_ext.py
index 01fdc75..9fda66e 100644
--- a/archivist/closure.py
+++ b/archivist/peewee_ext.py
@@ -1,4 +1,37 @@
from playhouse.sqlite_ext import VirtualModel, VirtualIntegerField, VirtualCharField
+from peewee import Field
+
+class EnumField(Field):
+ db_field = 'enum'
+
+ def __init__(self, enum_class, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.enum_class = enum_class
+
+ def _enum_value(self, value):
+ if isinstance(value, str):
+ try:
+ return self.enum_class[value.upper()]
+ except KeyError:
+ pass
+
+ try:
+ return self.enum_class(int(value))
+ except ValueError:
+ raise ValueError("%r is not a valid %s" % (value, self.enum_class.__name__))
+
+ def db_value(self, value):
+ if value is None:
+ return value
+
+ if isinstance(value, self.enum_class):
+ return value.value
+
+ # force check of enum value
+ return self._enum_value(value).value
+
+ def python_value(self, value):
+ return value if value is None else self._enum_value(value)
def ClosureTable(model_class, referencing_class = None, foreign_key=None, id_column = None):
"""Model factory for the transitive closure extension."""