Add validators
This commit is contained in:
parent
b7d7ef9638
commit
64aef7c18b
|
|
@ -1,17 +1,15 @@
|
|||
import inspect
|
||||
import re
|
||||
from collections.abc import Iterable
|
||||
from functools import reduce
|
||||
from operator import ior
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from tinydb import Query, TinyDB, table
|
||||
from tinydb import TinyDB, table
|
||||
from tinydb.storages import MemoryStorage
|
||||
from tinydb.table import Document
|
||||
|
||||
from grung.exceptions import CircularReferenceError, UniqueConstraintError
|
||||
from grung.types import Record
|
||||
from grung.exceptions import CircularReferenceError
|
||||
from grung.objects import Record
|
||||
from grung.validators import TypeValidator
|
||||
|
||||
|
||||
class RecordTable(table.Table):
|
||||
|
|
@ -26,6 +24,12 @@ class RecordTable(table.Table):
|
|||
|
||||
def insert(self, document):
|
||||
document.before_insert(self.db)
|
||||
|
||||
# check field types before attempting serialization
|
||||
validator = TypeValidator()
|
||||
for field in document._metadata.fields.values():
|
||||
validator.validate(document, field, self.db)
|
||||
|
||||
doc = document.serialize()
|
||||
self._check_constraints(doc)
|
||||
|
||||
|
|
@ -68,29 +72,17 @@ class RecordTable(table.Table):
|
|||
|
||||
def _check_constraints(self, document) -> bool:
|
||||
self._check_for_recursion(document)
|
||||
self._check_unique(document)
|
||||
for field in document._metadata.fields.values():
|
||||
field.validate(document, self.db)
|
||||
|
||||
def _check_for_recursion(self, document) -> bool:
|
||||
ref = document.reference
|
||||
for field in document._metadata.fields.values():
|
||||
if isinstance(field.default, Iterable) and ref in document[field.name]:
|
||||
raise CircularReferenceError(ref, field)
|
||||
raise CircularReferenceError(document, field, ref, "builtin")
|
||||
elif document[field.name] == ref:
|
||||
raise CircularReferenceError(ref, field)
|
||||
|
||||
def _check_unique(self, document) -> bool:
|
||||
matches = []
|
||||
queries = reduce(
|
||||
ior,
|
||||
[
|
||||
Query()[field.name].matches(f"^{document[field.name]}$", flags=re.IGNORECASE)
|
||||
for field in document._metadata.fields.values()
|
||||
if field.unique
|
||||
],
|
||||
)
|
||||
matches = [dict(match) for match in super().search(queries) if match.doc_id != document.doc_id]
|
||||
if matches != []:
|
||||
raise UniqueConstraintError(document, queries, matches)
|
||||
raise CircularReferenceError(document, field, ref, "builtin")
|
||||
return True
|
||||
|
||||
|
||||
class GrungDB(TinyDB):
|
||||
|
|
|
|||
|
|
@ -1,18 +1,21 @@
|
|||
from grung.types import (
|
||||
import re
|
||||
|
||||
from grung.objects import (
|
||||
BackReference,
|
||||
BinaryFilePointer,
|
||||
Collection,
|
||||
DateTime,
|
||||
Dict,
|
||||
Field,
|
||||
Integer,
|
||||
List,
|
||||
Password,
|
||||
Record,
|
||||
RecordDict,
|
||||
String,
|
||||
TextFilePointer,
|
||||
Timestamp,
|
||||
)
|
||||
from grung.validators import LengthValidator, MinMaxValidator, PatternValidator
|
||||
|
||||
|
||||
class User(Record):
|
||||
|
|
@ -20,13 +23,13 @@ class User(Record):
|
|||
def fields(cls):
|
||||
return [
|
||||
*super().fields(),
|
||||
Field("name", primary_key=True),
|
||||
Integer("number", default=0),
|
||||
Field("email", unique=True),
|
||||
String("name", primary_key=True, validators=[LengthValidator(min=3, max=30)]),
|
||||
Integer("number", default=0, validators=[MinMaxValidator(min=0, max=255)]),
|
||||
String("email", unique=True, validators=[PatternValidator(re.compile(r"[^@]+@[\w\-\.]+$"))]),
|
||||
Password("password"),
|
||||
DateTime("created"),
|
||||
Timestamp("last_updated"),
|
||||
BackReference("groups", Group),
|
||||
BackReference("groups", value_type=Group),
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -35,10 +38,10 @@ class Group(Record):
|
|||
def fields(cls):
|
||||
return [
|
||||
*super().fields(),
|
||||
Field("name", primary_key=True),
|
||||
Collection("members", User),
|
||||
Collection("groups", Group),
|
||||
BackReference("parent", Group),
|
||||
String("name", primary_key=True),
|
||||
Collection("members", member_type=User),
|
||||
Collection("groups", member_type=Group),
|
||||
BackReference("parent", value_type=Group),
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -47,10 +50,10 @@ class Album(Record):
|
|||
def fields(cls):
|
||||
inherited = [f for f in super().fields() if f.name != "name"]
|
||||
return inherited + [
|
||||
Field("name"),
|
||||
String("name"),
|
||||
Dict("credits"),
|
||||
List("tracks"),
|
||||
BackReference("artist", Artist),
|
||||
BackReference("artist", value_type=Artist),
|
||||
BinaryFilePointer("cover", extension=".jpg"),
|
||||
TextFilePointer("review"),
|
||||
]
|
||||
|
|
@ -60,4 +63,4 @@ class Artist(User):
|
|||
@classmethod
|
||||
def fields(cls):
|
||||
inherited = [f for f in super().fields() if f.name != "name"]
|
||||
return inherited + [Field("name"), RecordDict("albums", Album)]
|
||||
return inherited + [String("name"), RecordDict("albums", member_type=Album)]
|
||||
|
|
|
|||
|
|
@ -1,42 +1,92 @@
|
|||
class UniqueConstraintError(Exception):
|
||||
class ValidationError(Exception):
|
||||
"""
|
||||
Thrown when a record's field does not meet validation criteria.
|
||||
"""
|
||||
|
||||
messages = []
|
||||
template = (
|
||||
"\n"
|
||||
" * Record: {_record}\n"
|
||||
" * Field: {_field}\n"
|
||||
" * Value: {_value}\n"
|
||||
" * Validator: {_validator}\n"
|
||||
"\n"
|
||||
"{_messages}"
|
||||
)
|
||||
|
||||
def __init__(self, record, field, validator, messages=[], **kwargs):
|
||||
super().__init__(
|
||||
self.template.format(
|
||||
_record=dict(record),
|
||||
_field=field,
|
||||
_value=record[field.name],
|
||||
_validator=validator,
|
||||
_messages="\n".join(messages or self.__class__.messages),
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class InvalidFieldTypeError(ValidationError):
|
||||
"""
|
||||
Thrown when a document's field value does not match the field value_type.
|
||||
"""
|
||||
|
||||
messages = ["The value of the field is not an instance of the field's value_type."]
|
||||
|
||||
|
||||
class UniqueConstraintError(ValidationError):
|
||||
"""
|
||||
Thrown when a db write operation cannot complete due to a field's unique constraint.
|
||||
"""
|
||||
|
||||
def __init__(self, document, query, collisions):
|
||||
def __init__(self, record, field, validator, query, matches):
|
||||
super().__init__(
|
||||
"\n"
|
||||
f" * Record: {dict(document)}\n"
|
||||
f" * Query: {query}\n"
|
||||
f" * Error: Unique constraint failure\n"
|
||||
" * The record matches the following existing records:\n\n" + "\n".join(str(c) for c in collisions)
|
||||
record,
|
||||
field,
|
||||
validator,
|
||||
messages=[
|
||||
f"Query: {query}",
|
||||
"The record matches the following existing records:\n\n" + "\n".join(str(m) for m in matches),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class PointerReferenceError(Exception):
|
||||
"""
|
||||
Thrown when a document field containing a document could not be resolve to an existing record in the database.
|
||||
"""
|
||||
|
||||
def __init__(self, reference):
|
||||
super().__init__(
|
||||
"\n"
|
||||
f" * Reference: {reference}\n"
|
||||
f" * Error: Invalid Pointer\n"
|
||||
" * This collection member does not refer an existing record. Do you need to save it first?"
|
||||
)
|
||||
|
||||
|
||||
class CircularReferenceError(Exception):
|
||||
class CircularReferenceError(ValidationError):
|
||||
"""
|
||||
Thrown when a record contains a reference to itself.
|
||||
"""
|
||||
|
||||
def __init__(self, reference, field):
|
||||
super().__init__(
|
||||
"\n"
|
||||
f" * Reference: {reference}\n"
|
||||
f" * Field: {field.name}\n"
|
||||
f" * Error: Circular Reference\n"
|
||||
f" * This record contains a reference to itself. This will lead to infinite recursion."
|
||||
)
|
||||
messages = ["This record contains a reference to itself. This will lead to infinite recursion."]
|
||||
|
||||
|
||||
class MalformedPointerError(ValidationError):
|
||||
"""
|
||||
Thrown when a Pointer's value is not a valid reference string.
|
||||
"""
|
||||
|
||||
messages = ["A Pointer's value must follow the format 'TABLE_NAME::PRIMARY_KEY_NAME::PRIMARY_KEY_VALUE'."]
|
||||
|
||||
|
||||
class PointerReferenceError(Exception):
|
||||
"""
|
||||
Thrown when a record field containing a record could not be resolve to an existing record in the database.
|
||||
"""
|
||||
|
||||
|
||||
class InvalidLengthError(ValidationError):
|
||||
"""
|
||||
Thrown when a field does not meet its length constraint.
|
||||
"""
|
||||
|
||||
|
||||
class InvalidSizeError(ValidationError):
|
||||
"""
|
||||
Thrown when a field's size is too large or too small.
|
||||
"""
|
||||
|
||||
|
||||
class PatternMatchError(ValidationError):
|
||||
"""
|
||||
Thrown when a field does not match the specified pattern.
|
||||
"""
|
||||
|
|
|
|||
448
src/grung/objects.py
Normal file
448
src/grung/objects.py
Normal file
|
|
@ -0,0 +1,448 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import re
|
||||
import typing
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import nanoid
|
||||
from tinydb import TinyDB, where
|
||||
|
||||
import grung.types
|
||||
from grung.exceptions import PointerReferenceError
|
||||
from grung.validators import PointerReferenceValidator, UniqueValidator
|
||||
|
||||
Metadata = namedtuple("Metadata", ["table", "fields", "backrefs", "primary_key"])
|
||||
|
||||
|
||||
@dataclass
|
||||
class Field(grung.types.Field):
|
||||
"""
|
||||
Represents a single field in a Record.
|
||||
"""
|
||||
|
||||
name: str
|
||||
default: str = None
|
||||
unique: bool = False
|
||||
primary_key: bool = False
|
||||
validators: list = field(default_factory=lambda: [])
|
||||
|
||||
value_type = str
|
||||
|
||||
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
|
||||
pass
|
||||
|
||||
def after_insert(self, db: TinyDB, record: Record) -> None:
|
||||
pass
|
||||
|
||||
def serialize(self, value: value_type, record: Record | None = None) -> str:
|
||||
if value is not None:
|
||||
return str(value)
|
||||
|
||||
def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
|
||||
return value
|
||||
|
||||
def validate(self, record: Record, db: TinyDB):
|
||||
if self.unique:
|
||||
UniqueValidator().validate(record, self, db)
|
||||
for validator in self.validators:
|
||||
validator.validate(record, self, db)
|
||||
|
||||
|
||||
class Record(grung.types.Record):
|
||||
"""
|
||||
Base type for a single database record.
|
||||
"""
|
||||
|
||||
def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params):
|
||||
self.doc_id = doc_id
|
||||
fields = self.__class__.fields()
|
||||
|
||||
pkey = [field for field in fields if field.primary_key]
|
||||
if len(pkey) > 1:
|
||||
raise Exception(f"Cannnot have more than one primary key: {pkey}")
|
||||
elif pkey:
|
||||
pkey = pkey[0]
|
||||
else:
|
||||
# 1% collision rate at ~2M records
|
||||
pkey = Field("uid", default=nanoid.generate(size=8), primary_key=True)
|
||||
fields.append(pkey)
|
||||
|
||||
pkey.unique = True
|
||||
|
||||
self._metadata = Metadata(
|
||||
table=self.__class__.__name__,
|
||||
primary_key=pkey.name,
|
||||
fields={f.name: f for f in fields},
|
||||
backrefs=lambda value_type: (
|
||||
field for field in fields if type(field) == BackReference and field.value_type == value_type
|
||||
),
|
||||
)
|
||||
super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params))
|
||||
|
||||
@classmethod
|
||||
def fields(cls):
|
||||
return []
|
||||
|
||||
def serialize(self):
|
||||
"""
|
||||
Serialize every field on the record
|
||||
"""
|
||||
rec = {}
|
||||
for name, _field in self._metadata.fields.items():
|
||||
rec[name] = _field.serialize(self[name], record=self) if isinstance(_field, Field) else _field
|
||||
return self.__class__(rec, doc_id=self.doc_id)
|
||||
|
||||
def deserialize(self, db, recurse: bool = True):
|
||||
"""
|
||||
Deserialize every field on the record
|
||||
"""
|
||||
rec = {}
|
||||
for name, _field in self._metadata.fields.items():
|
||||
rec[name] = _field.deserialize(self[name], db, recurse=recurse)
|
||||
return self.__class__(rec, doc_id=self.doc_id)
|
||||
|
||||
def before_insert(self, db: TinyDB) -> None:
|
||||
for name, _field in self._metadata.fields.items():
|
||||
_field.before_insert(self[name], db, self)
|
||||
|
||||
def after_insert(self, db: TinyDB) -> None:
|
||||
for name, _field in self._metadata.fields.items():
|
||||
_field.after_insert(db, self)
|
||||
|
||||
def update(self, **data):
|
||||
for key, value in data.items():
|
||||
self[key] = value
|
||||
|
||||
@property
|
||||
def reference(self):
|
||||
return Pointer.reference(self)
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return Path(self._metadata.table) / self[self._metadata.primary_key]
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key in self:
|
||||
self[key] = value
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, attr_name):
|
||||
if attr_name in self:
|
||||
return self.get(attr_name)
|
||||
raise AttributeError(f"No such attribute: {attr_name}")
|
||||
|
||||
def __hash__(self):
|
||||
return hash(str(dict(self)))
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"{self.__class__.__name__}[{self.doc_id}]("
|
||||
+ ", ".join([f"{key}={val}" for (key, val) in self.items()])
|
||||
+ ")"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class String(Field):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Integer(Field):
|
||||
value_type = int
|
||||
default: int = 0
|
||||
|
||||
def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
|
||||
return int(value)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Dict(Field):
|
||||
default: dict = field(default_factory=lambda: {})
|
||||
|
||||
value_type = dict
|
||||
|
||||
def serialize(self, values: dict, record: Record | None = None) -> Dict[(str, str)]:
|
||||
return dict((key, str(value)) for key, value in values.items())
|
||||
|
||||
def deserialize(self, values: dict, db: TinyDB, recurse: bool = False) -> Dict[(str, str)]:
|
||||
return values
|
||||
|
||||
|
||||
@dataclass
|
||||
class List(Field):
|
||||
default: list = field(default_factory=lambda: [])
|
||||
|
||||
value_type = list
|
||||
|
||||
def serialize(self, values: list, record: Record | None = None) -> Dict[(str, str)]:
|
||||
return values
|
||||
|
||||
def deserialize(self, values: list, db: TinyDB, recurse: bool = False) -> typing.List[str]:
|
||||
return values
|
||||
|
||||
|
||||
@dataclass
|
||||
class DateTime(Field):
|
||||
default: datetime = datetime.utcfromtimestamp(0)
|
||||
|
||||
value_type = datetime
|
||||
|
||||
def serialize(self, value: value_type, record: Record | None = None) -> str:
|
||||
return (value - datetime.utcfromtimestamp(0)).total_seconds()
|
||||
|
||||
def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
|
||||
return datetime.utcfromtimestamp(int(value))
|
||||
|
||||
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
|
||||
if not value:
|
||||
record[self.name] = datetime.utcnow().replace(microsecond=0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Timestamp(DateTime):
|
||||
value_type = datetime
|
||||
|
||||
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
|
||||
super().before_insert(None, db, record)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Password(Field):
|
||||
value_type = str
|
||||
default: str = None
|
||||
|
||||
# Relatively weak. Consider using stronger initial values in production applications.
|
||||
salt_size = 4
|
||||
digest_size = 16
|
||||
|
||||
@classmethod
|
||||
def is_digest(cls, passwd: str):
|
||||
if not passwd:
|
||||
return False
|
||||
offset = 2 * cls.salt_size # each byte is 2 hex chars
|
||||
try:
|
||||
if passwd[offset] != ":":
|
||||
return False
|
||||
digest = passwd[(offset + 1) :] # noqa
|
||||
if len(digest) != cls.digest_size * 2:
|
||||
return False
|
||||
return re.match(r"^[0-9a-f]+$", digest)
|
||||
except IndexError:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_digest(cls, passwd: str, salt: bytes = None):
|
||||
if not salt:
|
||||
salt = os.urandom(cls.salt_size)
|
||||
digest = hashlib.blake2b(passwd.encode(), digest_size=cls.digest_size, salt=salt).hexdigest()
|
||||
return digest, salt.hex()
|
||||
|
||||
@classmethod
|
||||
def compare(cls, passwd: value_type, stored: value_type):
|
||||
stored_salt, stored_digest = stored.split(":")
|
||||
input_digest, input_salt = cls.get_digest(passwd, bytes.fromhex(stored_salt))
|
||||
return hmac.compare_digest(input_digest, stored_digest)
|
||||
|
||||
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
|
||||
if value and not self.__class__.is_digest(value):
|
||||
digest, salt = self.__class__.get_digest(value)
|
||||
record[self.name] = f"{salt}:{digest}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Pointer(Field):
|
||||
"""
|
||||
Store a string reference to a record.
|
||||
"""
|
||||
|
||||
name: str = ""
|
||||
value_type: grung.types.Record = Record
|
||||
|
||||
def serialize(self, value: value_type | str, record: Record | None = None) -> str:
|
||||
return Pointer.reference(value)
|
||||
|
||||
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
|
||||
return Pointer.dereference(value, db, recurse)
|
||||
|
||||
@classmethod
|
||||
def reference(cls, value: Record | str):
|
||||
if isinstance(value, str):
|
||||
PointerReferenceValidator().validate_string(value)
|
||||
return value
|
||||
|
||||
if value:
|
||||
return f"{value._metadata.table}::{value._metadata.primary_key}::{value[value._metadata.primary_key]}"
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def dereference(cls, value: str, db: TinyDB, recurse: bool = True):
|
||||
if not value:
|
||||
return
|
||||
elif type(value) == str:
|
||||
table_name, pkey, pval = value.split("::")
|
||||
if pval:
|
||||
table = db.table(table_name)
|
||||
rec = table.get(where(pkey) == pval, recurse=recurse)
|
||||
if not rec:
|
||||
raise PointerReferenceError(f"Expected a {table_name} with {pkey}=={pval} but did not find one!")
|
||||
return rec
|
||||
return value
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackReference(Pointer):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BinaryFilePointer(Field):
|
||||
"""
|
||||
Write the contents of this field to disk and store the path in the db.
|
||||
"""
|
||||
|
||||
name: str
|
||||
extension: str = ".blob"
|
||||
|
||||
value_type = bytes
|
||||
|
||||
def relpath(self, record):
|
||||
return Path(record._metadata.table) / record[record._metadata.primary_key] / f"{self.name}{self.extension}"
|
||||
|
||||
def reference(self, record):
|
||||
return f"/::{self.relpath(record)}"
|
||||
|
||||
def dereference(self, reference, db):
|
||||
relpath = reference.replace("/::", "", 1)
|
||||
try:
|
||||
return (db.path / relpath).read_bytes()
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
def serialize(self, value: value_type | str, record: Record | None = None) -> str:
|
||||
return self.reference(record)
|
||||
|
||||
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
|
||||
if not value:
|
||||
return None
|
||||
return self.dereference(value, db)
|
||||
|
||||
def prepare(self, data: value_type):
|
||||
"""
|
||||
Return bytes to be written to disk
|
||||
"""
|
||||
if not data:
|
||||
return
|
||||
if not isinstance(data, self.value_type):
|
||||
return data.encode()
|
||||
return data
|
||||
|
||||
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
|
||||
if not value:
|
||||
return
|
||||
relpath = self.relpath(record)
|
||||
path = db.path / relpath
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_bytes(self.prepare(value))
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextFilePointer(BinaryFilePointer):
|
||||
"""
|
||||
Write the contents of this field to disk and store the path in the db.
|
||||
"""
|
||||
|
||||
name: str
|
||||
extension: str = ".txt"
|
||||
|
||||
value_type = str
|
||||
|
||||
def prepare(self, data: value_type):
|
||||
if isinstance(data, bytes):
|
||||
return data
|
||||
return str(data).encode()
|
||||
|
||||
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
|
||||
if not value:
|
||||
return None
|
||||
buf = super().deserialize(value, db)
|
||||
return buf.decode() if buf else None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Collection(Field):
|
||||
"""
|
||||
A collection of pointers.
|
||||
"""
|
||||
|
||||
default: typing.List[value_type] = field(default_factory=lambda: [])
|
||||
member_type: type = Record
|
||||
|
||||
value_type = list
|
||||
|
||||
def serialize(self, values: typing.List[value_type], record: Record | None = None) -> typing.List[str]:
|
||||
return [Pointer.reference(val) for val in values]
|
||||
|
||||
def deserialize(self, values: typing.List[str], db: TinyDB, recurse: bool = False) -> typing.List[value_type]:
|
||||
"""
|
||||
Recursively deserialize the objects in this collection
|
||||
"""
|
||||
recs = []
|
||||
if not recurse:
|
||||
return values
|
||||
for val in values:
|
||||
recs.append(Pointer.dereference(val, db=db, recurse=False))
|
||||
return recs
|
||||
|
||||
def after_insert(self, db: TinyDB, record: Record) -> None:
|
||||
"""
|
||||
Populate any backreferences in the members of this collection with the parent record's uid.
|
||||
"""
|
||||
if not record[self.name]:
|
||||
return
|
||||
|
||||
for member in record[self.name]:
|
||||
target = Pointer.dereference(member, db=db, recurse=False)
|
||||
for backref in target._metadata.backrefs(type(record)):
|
||||
target[backref.name] = record
|
||||
db.save(target)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecordDict(Field):
|
||||
default: typing.Dict[(str, Record)] = field(default_factory=lambda: {})
|
||||
member_type: type = Record
|
||||
|
||||
value_type = dict
|
||||
|
||||
def serialize(
|
||||
self, values: typing.Dict[(str, value_type)], record: Record | None = None
|
||||
) -> typing.Dict[(str, str)]:
|
||||
return dict((key, Pointer.reference(val)) for (key, val) in values.items())
|
||||
|
||||
def deserialize(
|
||||
self, values: typing.Dict[(str, str)], db: TinyDB, recurse: bool = False
|
||||
) -> typing.Dict[(str, value_type)]:
|
||||
if not recurse:
|
||||
return values
|
||||
return dict((key, Pointer.dereference(val, db=db, recurse=False)) for (key, val) in values.items())
|
||||
|
||||
def after_insert(self, db: TinyDB, record: Record) -> None:
|
||||
"""
|
||||
Populate any backreferences in the members of this mapping with the parent record's uid.
|
||||
"""
|
||||
if not record[self.name]:
|
||||
return
|
||||
|
||||
for key, pointer in record[self.name].items():
|
||||
target = Pointer.dereference(pointer, db=db, recurse=False)
|
||||
for backref in target._metadata.backrefs(type(record)):
|
||||
target[backref.name] = record
|
||||
db.save(target)
|
||||
|
|
@ -1,420 +1,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import re
|
||||
import typing
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import nanoid
|
||||
from tinydb import TinyDB, where
|
||||
|
||||
from grung.exceptions import PointerReferenceError
|
||||
|
||||
Metadata = namedtuple("Metadata", ["table", "fields", "backrefs", "primary_key"])
|
||||
|
||||
|
||||
@dataclass
|
||||
class Field:
|
||||
"""
|
||||
Represents a single field in a Record.
|
||||
"""
|
||||
|
||||
name: str
|
||||
value_type: type = str
|
||||
default: str = None
|
||||
unique: bool = False
|
||||
primary_key: bool = False
|
||||
|
||||
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
|
||||
pass
|
||||
|
||||
def after_insert(self, db: TinyDB, record: Record) -> None:
|
||||
pass
|
||||
|
||||
def serialize(self, value: value_type, record: Record | None = None) -> str:
|
||||
if value is not None:
|
||||
return str(value)
|
||||
|
||||
def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
|
||||
return value
|
||||
|
||||
|
||||
@dataclass
|
||||
class Integer(Field):
|
||||
value_type = int
|
||||
default: int = 0
|
||||
|
||||
def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
|
||||
return int(value)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Dict(Field):
|
||||
value_type: type = str
|
||||
default: dict = field(default_factory=lambda: {})
|
||||
|
||||
def serialize(self, values: dict, record: Record | None = None) -> Dict[(str, str)]:
|
||||
return dict((key, str(value)) for key, value in values.items())
|
||||
|
||||
def deserialize(self, values: dict, db: TinyDB, recurse: bool = False) -> Dict[(str, str)]:
|
||||
return values
|
||||
|
||||
|
||||
@dataclass
|
||||
class List(Field):
|
||||
value_type: type = list
|
||||
default: list = field(default_factory=lambda: [])
|
||||
|
||||
def serialize(self, values: list, record: Record | None = None) -> Dict[(str, str)]:
|
||||
return values
|
||||
|
||||
def deserialize(self, values: list, db: TinyDB, recurse: bool = False) -> typing.List[str]:
|
||||
return values
|
||||
|
||||
|
||||
@dataclass
|
||||
class DateTime(Field):
|
||||
value_type: datetime
|
||||
default: datetime = datetime.utcfromtimestamp(0)
|
||||
|
||||
def serialize(self, value: value_type, record: Record | None = None) -> str:
|
||||
return (value - datetime.utcfromtimestamp(0)).total_seconds()
|
||||
|
||||
def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
|
||||
return datetime.utcfromtimestamp(int(value))
|
||||
|
||||
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
|
||||
if not value:
|
||||
record[self.name] = datetime.utcnow().replace(microsecond=0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Timestamp(DateTime):
|
||||
value_type: datetime
|
||||
|
||||
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
|
||||
super().before_insert(None, db, record)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Password(Field):
|
||||
value_type = str
|
||||
default: str = None
|
||||
|
||||
# Relatively weak. Consider using stronger initial values in production applications.
|
||||
salt_size = 4
|
||||
digest_size = 16
|
||||
|
||||
@classmethod
|
||||
def is_digest(cls, passwd: str):
|
||||
if not passwd:
|
||||
return False
|
||||
offset = 2 * cls.salt_size # each byte is 2 hex chars
|
||||
try:
|
||||
if passwd[offset] != ":":
|
||||
return False
|
||||
digest = passwd[(offset + 1) :] # noqa
|
||||
if len(digest) != cls.digest_size * 2:
|
||||
return False
|
||||
return re.match(r"^[0-9a-f]+$", digest)
|
||||
except IndexError:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_digest(cls, passwd: str, salt: bytes = None):
|
||||
if not salt:
|
||||
salt = os.urandom(cls.salt_size)
|
||||
digest = hashlib.blake2b(passwd.encode(), digest_size=cls.digest_size, salt=salt).hexdigest()
|
||||
return digest, salt.hex()
|
||||
|
||||
@classmethod
|
||||
def compare(cls, passwd: value_type, stored: value_type):
|
||||
stored_salt, stored_digest = stored.split(":")
|
||||
input_digest, input_salt = cls.get_digest(passwd, bytes.fromhex(stored_salt))
|
||||
return hmac.compare_digest(input_digest, stored_digest)
|
||||
|
||||
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
|
||||
if value and not self.__class__.is_digest(value):
|
||||
digest, salt = self.__class__.get_digest(value)
|
||||
record[self.name] = f"{salt}:{digest}"
|
||||
|
||||
|
||||
class Record(typing.Dict[(str, Field)]):
|
||||
"""
|
||||
Base type for a single database record.
|
||||
"""
|
||||
|
||||
def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params):
|
||||
self.doc_id = doc_id
|
||||
fields = self.__class__.fields()
|
||||
|
||||
pkey = [field for field in fields if field.primary_key]
|
||||
if len(pkey) > 1:
|
||||
raise Exception(f"Cannnot have more than one primary key: {pkey}")
|
||||
elif pkey:
|
||||
pkey = pkey[0]
|
||||
else:
|
||||
# 1% collision rate at ~2M records
|
||||
pkey = Field("uid", default=nanoid.generate(size=8), primary_key=True)
|
||||
fields.append(pkey)
|
||||
|
||||
pkey.unique = True
|
||||
|
||||
self._metadata = Metadata(
|
||||
table=self.__class__.__name__,
|
||||
primary_key=pkey.name,
|
||||
fields={f.name: f for f in fields},
|
||||
backrefs=lambda value_type: (
|
||||
field for field in fields if type(field) == BackReference and field.value_type == value_type
|
||||
),
|
||||
)
|
||||
super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params))
|
||||
|
||||
@classmethod
|
||||
def fields(cls):
|
||||
return []
|
||||
|
||||
def serialize(self):
|
||||
"""
|
||||
Serialize every field on the record
|
||||
"""
|
||||
rec = {}
|
||||
for name, _field in self._metadata.fields.items():
|
||||
rec[name] = _field.serialize(self[name], record=self) if isinstance(_field, Field) else _field
|
||||
return self.__class__(rec, doc_id=self.doc_id)
|
||||
|
||||
def deserialize(self, db, recurse: bool = True):
|
||||
"""
|
||||
Deserialize every field on the record
|
||||
"""
|
||||
rec = {}
|
||||
for name, _field in self._metadata.fields.items():
|
||||
rec[name] = _field.deserialize(self[name], db, recurse=recurse)
|
||||
return self.__class__(rec, doc_id=self.doc_id)
|
||||
|
||||
def before_insert(self, db: TinyDB) -> None:
|
||||
for name, _field in self._metadata.fields.items():
|
||||
_field.before_insert(self[name], db, self)
|
||||
|
||||
def after_insert(self, db: TinyDB) -> None:
|
||||
for name, _field in self._metadata.fields.items():
|
||||
_field.after_insert(db, self)
|
||||
|
||||
@property
|
||||
def reference(self):
|
||||
return Pointer.reference(self)
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return Path(self._metadata.table) / self[self._metadata.primary_key]
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key in self:
|
||||
self[key] = value
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, attr_name):
|
||||
if attr_name in self:
|
||||
return self.get(attr_name)
|
||||
raise AttributeError(f"No such attribute: {attr_name}")
|
||||
|
||||
def __hash__(self):
|
||||
return hash(str(dict(self)))
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"{self.__class__.__name__}[{self.doc_id}]("
|
||||
+ ", ".join([f"{key}={val}" for (key, val) in self.items()])
|
||||
+ ")"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Pointer(Field):
|
||||
"""
|
||||
Store a string reference to a record.
|
||||
"""
|
||||
|
||||
name: str = ""
|
||||
value_type: type = Record
|
||||
|
||||
def serialize(self, value: value_type | str, record: Record | None = None) -> str:
|
||||
return Pointer.reference(value)
|
||||
|
||||
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
|
||||
return Pointer.dereference(value, db, recurse)
|
||||
|
||||
@classmethod
|
||||
def reference(cls, value: Record | str):
|
||||
# XXX This could be smarter
|
||||
if isinstance(value, str):
|
||||
if "::" not in value:
|
||||
raise PointerReferenceError("Value {value} does not look like a reference!")
|
||||
return value
|
||||
if value:
|
||||
return f"{value._metadata.table}::{value._metadata.primary_key}::{value[value._metadata.primary_key]}"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def dereference(cls, value: str, db: TinyDB, recurse: bool = True):
|
||||
if not value:
|
||||
return
|
||||
elif type(value) == str:
|
||||
table_name, pkey, pval = value.split("::")
|
||||
if pval:
|
||||
table = db.table(table_name)
|
||||
rec = table.get(where(pkey) == pval, recurse=recurse)
|
||||
if not rec:
|
||||
raise PointerReferenceError(f"Expected a {table_name} with {pkey}=={pval} but did not find one!")
|
||||
return rec
|
||||
return value
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackReference(Pointer):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BinaryFilePointer(Field):
|
||||
"""
|
||||
Write the contents of this field to disk and store the path in the db.
|
||||
"""
|
||||
|
||||
name: str
|
||||
value_type: type = bytes
|
||||
extension: str = ".blob"
|
||||
|
||||
def relpath(self, record):
|
||||
return Path(record._metadata.table) / record[record._metadata.primary_key] / f"{self.name}{self.extension}"
|
||||
|
||||
def reference(self, record):
|
||||
return f"/::{self.relpath(record)}"
|
||||
|
||||
def dereference(self, reference, db):
|
||||
relpath = reference.replace("/::", "", 1)
|
||||
try:
|
||||
return (db.path / relpath).read_bytes()
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
def serialize(self, value: value_type | str, record: Record | None = None) -> str:
|
||||
return self.reference(record)
|
||||
|
||||
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
|
||||
if not value:
|
||||
return None
|
||||
return self.dereference(value, db)
|
||||
|
||||
def prepare(self, data: value_type):
|
||||
"""
|
||||
Return bytes to be written to disk
|
||||
"""
|
||||
if not data:
|
||||
return
|
||||
if not isinstance(data, self.value_type):
|
||||
return data.encode()
|
||||
return data
|
||||
|
||||
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
|
||||
if not value:
|
||||
return
|
||||
relpath = self.relpath(record)
|
||||
path = db.path / relpath
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_bytes(self.prepare(value))
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextFilePointer(BinaryFilePointer):
|
||||
"""
|
||||
Write the contents of this field to disk and store the path in the db.
|
||||
"""
|
||||
|
||||
name: str
|
||||
value_type: type = str
|
||||
extension: str = ".txt"
|
||||
|
||||
def prepare(self, data: value_type):
|
||||
if isinstance(data, bytes):
|
||||
return data
|
||||
return str(data).encode()
|
||||
|
||||
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
|
||||
if not value:
|
||||
return None
|
||||
buf = super().deserialize(value, db)
|
||||
return buf.decode() if buf else None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Collection(Field):
|
||||
"""
|
||||
A collection of pointers.
|
||||
"""
|
||||
|
||||
value_type: type = Record
|
||||
default: typing.List[value_type] = field(default_factory=lambda: [])
|
||||
|
||||
def serialize(self, values: typing.List[value_type], record: Record | None = None) -> typing.List[str]:
|
||||
return [Pointer.reference(val) for val in values]
|
||||
|
||||
def deserialize(self, values: typing.List[str], db: TinyDB, recurse: bool = False) -> typing.List[value_type]:
|
||||
"""
|
||||
Recursively deserialize the objects in this collection
|
||||
"""
|
||||
recs = []
|
||||
if not recurse:
|
||||
return values
|
||||
for val in values:
|
||||
recs.append(Pointer.dereference(val, db=db, recurse=False))
|
||||
return recs
|
||||
|
||||
def after_insert(self, db: TinyDB, record: Record) -> None:
|
||||
"""
|
||||
Populate any backreferences in the members of this collection with the parent record's uid.
|
||||
"""
|
||||
if not record[self.name]:
|
||||
return
|
||||
|
||||
for member in record[self.name]:
|
||||
target = Pointer.dereference(member, db=db, recurse=False)
|
||||
for backref in target._metadata.backrefs(type(record)):
|
||||
target[backref.name] = record
|
||||
db.save(target)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecordDict(Field):
|
||||
value_type: type = Record
|
||||
default: typing.Dict[(str, Record)] = field(default_factory=lambda: {})
|
||||
|
||||
def serialize(
|
||||
self, values: typing.Dict[(str, value_type)], record: Record | None = None
|
||||
) -> typing.Dict[(str, str)]:
|
||||
return dict((key, Pointer.reference(val)) for (key, val) in values.items())
|
||||
|
||||
def deserialize(
|
||||
self, values: typing.Dict[(str, str)], db: TinyDB, recurse: bool = False
|
||||
) -> typing.Dict[(str, value_type)]:
|
||||
if not recurse:
|
||||
return values
|
||||
return dict((key, Pointer.dereference(val, db=db, recurse=False)) for (key, val) in values.items())
|
||||
|
||||
def after_insert(self, db: TinyDB, record: Record) -> None:
|
||||
"""
|
||||
Populate any backreferences in the members of this mapping with the parent record's uid.
|
||||
"""
|
||||
if not record[self.name]:
|
||||
return
|
||||
|
||||
for key, pointer in record[self.name].items():
|
||||
target = Pointer.dereference(pointer, db=db, recurse=False)
|
||||
for backref in target._metadata.backrefs(type(record)):
|
||||
target[backref.name] = record
|
||||
db.save(target)
|
||||
class Record(typing.Dict[(str, Field)]):
|
||||
pass
|
||||
|
|
|
|||
199
src/grung/validators.py
Normal file
199
src/grung/validators.py
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
from tinydb import Query, TinyDB
|
||||
|
||||
from grung.exceptions import (
|
||||
InvalidFieldTypeError,
|
||||
InvalidLengthError,
|
||||
InvalidSizeError,
|
||||
MalformedPointerError,
|
||||
PatternMatchError,
|
||||
UniqueConstraintError,
|
||||
ValidationError,
|
||||
)
|
||||
from grung.types import Field, Record
|
||||
|
||||
|
||||
@dataclass
|
||||
class Validator:
|
||||
def validate(self, record: Record, field: Field, db: TinyDB = None) -> bool:
|
||||
raise ValidationError(record, field, self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TypeValidator:
|
||||
def validate_list(self, record: Record, field: Field, db: TinyDB = None) -> bool:
|
||||
messages = []
|
||||
for i in range(len(record[field.name])):
|
||||
member = record[field.name][i]
|
||||
if isinstance(member, str):
|
||||
class_name = field.member_type.__name__
|
||||
try:
|
||||
return PointerReferenceValidator().validate_string(member, class_name)
|
||||
except MalformedPointerError as e:
|
||||
messages.append(str(e))
|
||||
elif not isinstance(member, field.member_type):
|
||||
messages.append(f"{field.name}[{i}] must be a {field.member_type}, not a {type(member)}.")
|
||||
if messages:
|
||||
raise InvalidFieldTypeError(record, field, self, messages=messages)
|
||||
|
||||
def validate_dict(self, record: Record, field: Field, db: TinyDB = None) -> bool:
|
||||
for key, member in record[field.name].items():
|
||||
if not isinstance(member, field.member_type):
|
||||
raise InvalidFieldTypeError(
|
||||
record,
|
||||
field,
|
||||
self,
|
||||
messages=[f"{field.name}[{key} must be a {field.member_type}, not a {type(member)}."],
|
||||
)
|
||||
return True
|
||||
|
||||
def validate(self, record: Record, field: Field, db: TinyDB = None) -> bool:
|
||||
if record[field.name] is None:
|
||||
return True
|
||||
|
||||
if not isinstance(record[field.name], field.value_type):
|
||||
raise InvalidFieldTypeError(
|
||||
record,
|
||||
field,
|
||||
self,
|
||||
messages=[f"{field.name} must be a {field.value_type}, not a {type(record[field.name])}."],
|
||||
)
|
||||
if not hasattr(field, "member_type"):
|
||||
return True
|
||||
|
||||
if field.value_type == dict:
|
||||
self.validate_dict(record, field, db)
|
||||
elif field.value_type == list:
|
||||
self.validate_list(record, field, db)
|
||||
else:
|
||||
raise RuntimeError("Expected a validation for iterable but didn't get one!")
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class PointerReferenceValidator(Validator):
|
||||
"""
|
||||
Verify that the Pointer is either a string reference to the correct member type,
|
||||
or a record instance of member_type that hasn't been serialized.
|
||||
"""
|
||||
|
||||
def validate_string(self, value: str, type_name: str = "") -> bool:
|
||||
(table, primary_key, val) = value.split("::")
|
||||
if type_name and table != type_name:
|
||||
raise MalformedPointerError(
|
||||
{"string": value},
|
||||
"",
|
||||
self,
|
||||
messages=[f"field should reference '{type_name}', not '{table}'."],
|
||||
)
|
||||
if not primary_key:
|
||||
raise MalformedPointerError(
|
||||
{"string": value},
|
||||
"",
|
||||
self,
|
||||
messages=["Pointers must specify the primary_key field name."],
|
||||
)
|
||||
|
||||
def validate(self, record: Record | str, field: Field, db: TinyDB = None) -> bool:
|
||||
if record[field.name] is None:
|
||||
return True
|
||||
|
||||
if isinstance(record, str):
|
||||
try:
|
||||
self.validate_string(record, field.value_type.__name__)
|
||||
except ValueError:
|
||||
raise MalformedPointerError({field.name: record}, field, self)
|
||||
return True
|
||||
|
||||
if not isinstance(record, Record):
|
||||
raise MalformedPointerError(record, field, self)
|
||||
|
||||
if not isinstance(record[field.name], field.value_type):
|
||||
raise MalformedPointerError(record, field, self)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class UniqueValidator(Validator):
|
||||
def validate(self, record: Record, field: Field, db: TinyDB) -> bool:
|
||||
"""
|
||||
Returns true if the field's value is unique across all records in the table.
|
||||
"""
|
||||
if record[field.name] is None:
|
||||
return True
|
||||
|
||||
query = Query()[field.name].matches(f"^{record[field.name]}$", flags=re.IGNORECASE)
|
||||
table = db.table(record._metadata.table)
|
||||
matches = [dict(match) for match in table.search(query) if match.doc_id != record.doc_id]
|
||||
if matches != []:
|
||||
raise UniqueConstraintError(record, field, self, query=query, matches=matches)
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class LengthValidator(Validator):
|
||||
min: int = 0
|
||||
max: int = 0
|
||||
|
||||
def validate(self, record: Record, field: Field, db: TinyDB = None) -> bool:
|
||||
"""
|
||||
Returns True if the length of the field's value is between min and max, inclusive.
|
||||
"""
|
||||
if record[field.name] is None:
|
||||
return True
|
||||
|
||||
length = len(record[field.name])
|
||||
if length < self.min or length > self.max:
|
||||
raise InvalidLengthError(
|
||||
record,
|
||||
field,
|
||||
self,
|
||||
messages=[f"The field length must be between {self.min} and {self.max}, inclusive."],
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class MinMaxValidator(Validator):
|
||||
min: int = 0
|
||||
max: int = 0
|
||||
|
||||
def validate(self, record: Record, field: Field, db: TinyDB = None) -> bool:
|
||||
"""
|
||||
Returns True if the size of the field's integer value is between min and max, inclusive.
|
||||
"""
|
||||
if record[field.name] is None:
|
||||
return True
|
||||
|
||||
size = int(record[field.name])
|
||||
if size < self.min or size > self.max:
|
||||
raise InvalidSizeError(
|
||||
record,
|
||||
field,
|
||||
self,
|
||||
messages=[f"The field size must be between {self.min} and {self.max}, inclusive."],
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class PatternValidator(Validator):
|
||||
pattern: re.Pattern
|
||||
|
||||
def validate(self, record: Record, field: Field, db: TinyDB = None) -> bool:
|
||||
if record[field.name] is None:
|
||||
return True
|
||||
|
||||
if not self.pattern.match(record[field.name]):
|
||||
raise PatternMatchError(
|
||||
record,
|
||||
field,
|
||||
self,
|
||||
messages=[f"The field value must match the pattern {self.pattern}"],
|
||||
)
|
||||
return True
|
||||
|
|
@ -10,7 +10,14 @@ from tinydb.storages import MemoryStorage
|
|||
|
||||
from grung import examples
|
||||
from grung.db import GrungDB
|
||||
from grung.exceptions import CircularReferenceError, UniqueConstraintError
|
||||
from grung.exceptions import (
|
||||
CircularReferenceError,
|
||||
InvalidFieldTypeError,
|
||||
InvalidLengthError,
|
||||
InvalidSizeError,
|
||||
PatternMatchError,
|
||||
UniqueConstraintError,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -82,7 +89,7 @@ def test_subgroups(db):
|
|||
|
||||
# recursion!
|
||||
with pytest.raises(CircularReferenceError):
|
||||
tos.members = [tos]
|
||||
tos.groups = [tos]
|
||||
db.save(tos)
|
||||
|
||||
|
||||
|
|
@ -201,3 +208,29 @@ def test_file_pointers(db):
|
|||
|
||||
location_on_disk = db.path / album._metadata.fields["review"].relpath(album)
|
||||
assert location_on_disk.read_text() == album.review
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"updates, expected",
|
||||
[
|
||||
({"name": ""}, InvalidLengthError),
|
||||
({"name": "a name longer than 30 characters is what we have here"}, InvalidLengthError),
|
||||
({"name": 23}, InvalidFieldTypeError),
|
||||
({"number": -1}, InvalidSizeError),
|
||||
({"number": 256}, InvalidSizeError),
|
||||
({"email": "foo+alias@"}, PatternMatchError),
|
||||
],
|
||||
ids=[
|
||||
"name too short",
|
||||
"name too long",
|
||||
"name is not a string",
|
||||
"number too small",
|
||||
"number too big",
|
||||
"invalid email addres",
|
||||
],
|
||||
)
|
||||
def test_validators(updates, expected, db):
|
||||
user = db.save(examples.User(name="john", email="john@foo", password="fnord", created=datetime.utcnow()))
|
||||
with pytest.raises(expected):
|
||||
user.update(**updates)
|
||||
db.save(user)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user