Add validators

This commit is contained in:
evilchili 2025-11-01 22:25:15 -07:00
parent b7d7ef9638
commit 64aef7c18b
7 changed files with 795 additions and 481 deletions

View File

@ -1,17 +1,15 @@
import inspect
import re
from collections.abc import Iterable
from functools import reduce
from operator import ior
from pathlib import Path
from typing import List
from tinydb import Query, TinyDB, table
from tinydb import TinyDB, table
from tinydb.storages import MemoryStorage
from tinydb.table import Document
from grung.exceptions import CircularReferenceError, UniqueConstraintError
from grung.types import Record
from grung.exceptions import CircularReferenceError
from grung.objects import Record
from grung.validators import TypeValidator
class RecordTable(table.Table):
@ -26,6 +24,12 @@ class RecordTable(table.Table):
def insert(self, document):
document.before_insert(self.db)
# check field types before attempting serialization
validator = TypeValidator()
for field in document._metadata.fields.values():
validator.validate(document, field, self.db)
doc = document.serialize()
self._check_constraints(doc)
@ -68,29 +72,17 @@ class RecordTable(table.Table):
def _check_constraints(self, document) -> bool:
self._check_for_recursion(document)
self._check_unique(document)
for field in document._metadata.fields.values():
field.validate(document, self.db)
def _check_for_recursion(self, document) -> bool:
ref = document.reference
for field in document._metadata.fields.values():
if isinstance(field.default, Iterable) and ref in document[field.name]:
raise CircularReferenceError(ref, field)
raise CircularReferenceError(document, field, ref, "builtin")
elif document[field.name] == ref:
raise CircularReferenceError(ref, field)
def _check_unique(self, document) -> bool:
matches = []
queries = reduce(
ior,
[
Query()[field.name].matches(f"^{document[field.name]}$", flags=re.IGNORECASE)
for field in document._metadata.fields.values()
if field.unique
],
)
matches = [dict(match) for match in super().search(queries) if match.doc_id != document.doc_id]
if matches != []:
raise UniqueConstraintError(document, queries, matches)
raise CircularReferenceError(document, field, ref, "builtin")
return True
class GrungDB(TinyDB):

View File

@ -1,18 +1,21 @@
from grung.types import (
import re
from grung.objects import (
BackReference,
BinaryFilePointer,
Collection,
DateTime,
Dict,
Field,
Integer,
List,
Password,
Record,
RecordDict,
String,
TextFilePointer,
Timestamp,
)
from grung.validators import LengthValidator, MinMaxValidator, PatternValidator
class User(Record):
@ -20,13 +23,13 @@ class User(Record):
def fields(cls):
return [
*super().fields(),
Field("name", primary_key=True),
Integer("number", default=0),
Field("email", unique=True),
String("name", primary_key=True, validators=[LengthValidator(min=3, max=30)]),
Integer("number", default=0, validators=[MinMaxValidator(min=0, max=255)]),
String("email", unique=True, validators=[PatternValidator(re.compile(r"[^@]+@[\w\-\.]+$"))]),
Password("password"),
DateTime("created"),
Timestamp("last_updated"),
BackReference("groups", Group),
BackReference("groups", value_type=Group),
]
@ -35,10 +38,10 @@ class Group(Record):
def fields(cls):
return [
*super().fields(),
Field("name", primary_key=True),
Collection("members", User),
Collection("groups", Group),
BackReference("parent", Group),
String("name", primary_key=True),
Collection("members", member_type=User),
Collection("groups", member_type=Group),
BackReference("parent", value_type=Group),
]
@ -47,10 +50,10 @@ class Album(Record):
def fields(cls):
inherited = [f for f in super().fields() if f.name != "name"]
return inherited + [
Field("name"),
String("name"),
Dict("credits"),
List("tracks"),
BackReference("artist", Artist),
BackReference("artist", value_type=Artist),
BinaryFilePointer("cover", extension=".jpg"),
TextFilePointer("review"),
]
@ -60,4 +63,4 @@ class Artist(User):
@classmethod
def fields(cls):
inherited = [f for f in super().fields() if f.name != "name"]
return inherited + [Field("name"), RecordDict("albums", Album)]
return inherited + [String("name"), RecordDict("albums", member_type=Album)]

View File

@ -1,42 +1,92 @@
class UniqueConstraintError(Exception):
class ValidationError(Exception):
    """
    Thrown when a record's field does not meet validation criteria.

    Subclasses override the class attribute ``messages`` to supply a default
    explanation; callers may also pass explicit ``messages``.
    """

    # Default explanation lines; subclasses override this.
    messages = []

    # Human-readable report rendered into the exception text.
    template = (
        "\n"
        " * Record: {_record}\n"
        " * Field: {_field}\n"
        " * Value: {_value}\n"
        " * Validator: {_validator}\n"
        "\n"
        "{_messages}"
    )

    def __init__(self, record, field, validator, messages=None, **kwargs):
        # messages=None instead of a mutable default list; falsy values fall
        # back to the class-level messages, preserving the old behavior.
        #
        # `field` may be a Field-like object or a bare string (see
        # PointerReferenceValidator.validate_string, which passes ""), so only
        # index the record when there is a real field name to index with.
        field_name = getattr(field, "name", None)
        value = record[field_name] if field_name is not None else record
        super().__init__(
            self.template.format(
                _record=dict(record),
                _field=field,
                _value=value,
                _validator=validator,
                _messages="\n".join(messages or self.__class__.messages),
                **kwargs,
            )
        )
class InvalidFieldTypeError(ValidationError):
    """
    Thrown when a document's field value does not match the field value_type.
    """

    messages = ["The value of the field is not an instance of the field's value_type."]
class UniqueConstraintError(ValidationError):
"""
Thrown when a db write operation cannot complete due to a field's unique constraint.
"""
def __init__(self, document, query, collisions):
def __init__(self, record, field, validator, query, matches):
super().__init__(
"\n"
f" * Record: {dict(document)}\n"
f" * Query: {query}\n"
f" * Error: Unique constraint failure\n"
" * The record matches the following existing records:\n\n" + "\n".join(str(c) for c in collisions)
record,
field,
validator,
messages=[
f"Query: {query}",
"The record matches the following existing records:\n\n" + "\n".join(str(m) for m in matches),
],
)
class PointerReferenceError(Exception):
"""
Thrown when a document field containing a document could not be resolve to an existing record in the database.
"""
def __init__(self, reference):
super().__init__(
"\n"
f" * Reference: {reference}\n"
f" * Error: Invalid Pointer\n"
" * This collection member does not refer an existing record. Do you need to save it first?"
)
class CircularReferenceError(Exception):
class CircularReferenceError(ValidationError):
"""
Thrown when a record contains a reference to itself.
"""
def __init__(self, reference, field):
super().__init__(
"\n"
f" * Reference: {reference}\n"
f" * Field: {field.name}\n"
f" * Error: Circular Reference\n"
f" * This record contains a reference to itself. This will lead to infinite recursion."
)
messages = ["This record contains a reference to itself. This will lead to infinite recursion."]
class MalformedPointerError(ValidationError):
    """
    Thrown when a Pointer's value is not a valid reference string.
    """

    messages = ["A Pointer's value must follow the format 'TABLE_NAME::PRIMARY_KEY_NAME::PRIMARY_KEY_VALUE'."]
class PointerReferenceError(Exception):
    """
    Thrown when a record field containing a record could not be resolved to an existing record in the database.
    """
class InvalidLengthError(ValidationError):
    """
    Thrown when a field does not meet its length constraint.
    """
class InvalidSizeError(ValidationError):
    """
    Thrown when a field's size is too large or too small.
    """
class PatternMatchError(ValidationError):
    """
    Thrown when a field does not match the specified pattern.
    """

448
src/grung/objects.py Normal file
View File

@ -0,0 +1,448 @@
from __future__ import annotations
import hashlib
import hmac
import os
import re
import typing
from collections import namedtuple
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
import nanoid
from tinydb import TinyDB, where
import grung.types
from grung.exceptions import PointerReferenceError
from grung.validators import PointerReferenceValidator, UniqueValidator
Metadata = namedtuple("Metadata", ["table", "fields", "backrefs", "primary_key"])
@dataclass
class Field(grung.types.Field):
    """
    Represents a single field in a Record.

    Subclasses customize serialization, lifecycle hooks and validation;
    the base implementation stores values as their string representation.
    """

    name: str
    default: str = None
    unique: bool = False
    primary_key: bool = False
    validators: list = field(default_factory=list)

    # Expected python type of the (deserialized) field value.
    value_type = str

    def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
        """Hook invoked before the owning record is written; no-op by default."""
        pass

    def after_insert(self, db: TinyDB, record: Record) -> None:
        """Hook invoked after the owning record is written; no-op by default."""
        pass

    def serialize(self, value: value_type, record: Record | None = None) -> str:
        """Return the storable string form of the value; None stays None."""
        return str(value) if value is not None else None

    def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
        """Return the in-memory form of a stored value; identity by default."""
        return value

    def validate(self, record: Record, db: TinyDB):
        """Run the unique-constraint check (when configured), then every attached validator."""
        checks = list(self.validators)
        if self.unique:
            checks.insert(0, UniqueValidator())
        for check in checks:
            check.validate(record, self, db)
class Record(grung.types.Record):
    """
    Base type for a single database record.

    Behaves like a dict keyed by field name, with ``_metadata`` describing the
    table name, fields, primary key and back-references. A record without an
    explicit primary-key field gets an auto-generated ``uid``.
    """

    def __init__(self, raw_doc: dict = None, doc_id: int = None, **params):
        # raw_doc defaults to None (not a shared mutable {}).
        self.doc_id = doc_id
        fields = self.__class__.fields()
        pkey = [field for field in fields if field.primary_key]
        if len(pkey) > 1:
            raise Exception(f"Cannot have more than one primary key: {pkey}")
        elif pkey:
            pkey = pkey[0]
        else:
            # 1% collision rate at ~2M records
            pkey = Field("uid", default=nanoid.generate(size=8), primary_key=True)
            fields.append(pkey)
        # The primary key is always unique, whether declared or auto-generated.
        pkey.unique = True
        self._metadata = Metadata(
            table=self.__class__.__name__,
            primary_key=pkey.name,
            fields={f.name: f for f in fields},
            backrefs=lambda value_type: (
                field for field in fields if type(field) == BackReference and field.value_type == value_type
            ),
        )
        # Field defaults, overlaid by the stored document, overlaid by keyword params.
        super().__init__(dict({field.name: field.default for field in fields}, **(raw_doc or {}), **params))

    @classmethod
    def fields(cls):
        """Subclasses return the list of Field definitions for this record type."""
        return []

    def serialize(self):
        """
        Serialize every field on the record
        """
        rec = {}
        for name, _field in self._metadata.fields.items():
            rec[name] = _field.serialize(self[name], record=self) if isinstance(_field, Field) else _field
        return self.__class__(rec, doc_id=self.doc_id)

    def deserialize(self, db, recurse: bool = True):
        """
        Deserialize every field on the record
        """
        rec = {}
        for name, _field in self._metadata.fields.items():
            rec[name] = _field.deserialize(self[name], db, recurse=recurse)
        return self.__class__(rec, doc_id=self.doc_id)

    def before_insert(self, db: TinyDB) -> None:
        """Give each field a chance to mutate the record before it is written."""
        for name, _field in self._metadata.fields.items():
            _field.before_insert(self[name], db, self)

    def after_insert(self, db: TinyDB) -> None:
        """Give each field a chance to react after the record is written."""
        for name, _field in self._metadata.fields.items():
            _field.after_insert(db, self)

    def update(self, **data):
        """Assign each keyword argument as a field value on this record."""
        for key, value in data.items():
            self[key] = value

    @property
    def reference(self):
        """The serialized pointer string referring to this record."""
        return Pointer.reference(self)

    @property
    def path(self):
        """Relative path for this record's on-disk artifacts: TABLE/PRIMARY_KEY_VALUE."""
        return Path(self._metadata.table) / self[self._metadata.primary_key]

    def __setattr__(self, key, value):
        # Keep attribute access and dict access in sync for known field names.
        if key in self:
            self[key] = value
        super().__setattr__(key, value)

    def __getattr__(self, attr_name):
        if attr_name in self:
            return self.get(attr_name)
        raise AttributeError(f"No such attribute: {attr_name}")

    def __hash__(self):
        return hash(str(dict(self)))

    def __repr__(self):
        return (
            f"{self.__class__.__name__}[{self.doc_id}]("
            + ", ".join([f"{key}={val}" for (key, val) in self.items()])
            + ")"
        )
@dataclass
class String(Field):
    """A plain string field; inherits all Field behavior unchanged."""

    pass
@dataclass
class Integer(Field):
    """An integer field, stored as its string representation."""

    value_type = int
    default: int = 0

    def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
        """Return the stored value as an int; an unset (None) value passes through."""
        # Guard against unset values rather than raising TypeError from int(None).
        return int(value) if value is not None else None
@dataclass
class Dict(Field):
    """A mapping field; values are coerced to strings on serialization."""

    default: dict = field(default_factory=lambda: {})
    value_type = dict

    def serialize(self, values: dict, record: Record | None = None) -> typing.Dict[str, str]:
        """Return a copy of the mapping with every value stringified."""
        return dict((key, str(value)) for key, value in values.items())

    def deserialize(self, values: dict, db: TinyDB, recurse: bool = False) -> typing.Dict[str, str]:
        """Stored mappings are returned as-is."""
        return values
@dataclass
class List(Field):
    """A list field; members are stored and returned as-is."""

    default: list = field(default_factory=lambda: [])
    value_type = list

    def serialize(self, values: list, record: Record | None = None) -> typing.List[str]:
        """Return the list unchanged."""
        return values

    def deserialize(self, values: list, db: TinyDB, recurse: bool = False) -> typing.List[str]:
        """Return the stored list unchanged."""
        return values
@dataclass
class DateTime(Field):
    """A datetime field, stored as seconds since the UNIX epoch."""

    default: datetime = datetime.utcfromtimestamp(0)
    value_type = datetime

    def serialize(self, value: value_type, record: Record | None = None) -> float:
        """Return the value as (float) seconds since the epoch."""
        return (value - datetime.utcfromtimestamp(0)).total_seconds()

    def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
        """Rebuild a naive UTC datetime from stored epoch seconds; sub-second precision is dropped by int()."""
        return datetime.utcfromtimestamp(int(value))

    def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
        """Default an unset value to the current UTC time (seconds resolution)."""
        # NOTE(review): datetime.utcnow/utcfromtimestamp produce naive datetimes
        # and are deprecated in Python 3.12+; consider timezone-aware values — confirm.
        if not value:
            record[self.name] = datetime.utcnow().replace(microsecond=0)
@dataclass
class Timestamp(DateTime):
    """A DateTime that is refreshed to 'now' on every insert."""

    value_type = datetime

    def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
        # Passing None forces the parent hook to overwrite the field with the current time.
        super().before_insert(None, db, record)
@dataclass
class Password(Field):
    """
    A password field: plaintext values are replaced before insert with a
    salted BLAKE2b digest stored as 'SALT_HEX:DIGEST_HEX'.
    """

    value_type = str
    default: str = None

    # Relatively weak. Consider using stronger initial values in production applications.
    salt_size = 4
    digest_size = 16

    @classmethod
    def is_digest(cls, passwd: str):
        """Return truthy if passwd already looks like a stored 'salt:digest' string."""
        if not passwd:
            return False
        offset = 2 * cls.salt_size  # each byte is 2 hex chars
        try:
            if passwd[offset] != ":":
                return False
            digest = passwd[(offset + 1) :]  # noqa
            if len(digest) != cls.digest_size * 2:
                return False
            return re.match(r"^[0-9a-f]+$", digest)
        except IndexError:
            return False

    @classmethod
    def get_digest(cls, passwd: str, salt: bytes = None):
        """Return (digest_hex, salt_hex) for passwd, generating a random salt if none is given."""
        if not salt:
            salt = os.urandom(cls.salt_size)
        digest = hashlib.blake2b(passwd.encode(), digest_size=cls.digest_size, salt=salt).hexdigest()
        return digest, salt.hex()

    @classmethod
    def compare(cls, passwd: value_type, stored: value_type):
        """Timing-safe comparison of a plaintext password against a stored 'salt:digest' value."""
        stored_salt, stored_digest = stored.split(":")
        input_digest, input_salt = cls.get_digest(passwd, bytes.fromhex(stored_salt))
        return hmac.compare_digest(input_digest, stored_digest)

    def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
        """Hash a plaintext value in place; already-hashed values are left untouched."""
        if value and not self.__class__.is_digest(value):
            digest, salt = self.__class__.get_digest(value)
            record[self.name] = f"{salt}:{digest}"
@dataclass
class Pointer(Field):
    """
    Store a string reference to a record.

    References take the form 'TABLE::PRIMARY_KEY_NAME::PRIMARY_KEY_VALUE'.
    """

    name: str = ""
    value_type: grung.types.Record = Record

    def serialize(self, value: value_type | str, record: Record | None = None) -> str:
        """Return the reference string for a record (or pass a valid reference through)."""
        return Pointer.reference(value)

    def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
        """Resolve a reference string back into a record."""
        return Pointer.dereference(value, db, recurse)

    @classmethod
    def reference(cls, value: Record | str):
        """
        Return the reference string for value.

        Strings are validated and passed through; records are rendered as
        'TABLE::PKEY_NAME::PKEY_VALUE'; falsy values yield None.
        """
        if isinstance(value, str):
            PointerReferenceValidator().validate_string(value)
            return value
        if value:
            return f"{value._metadata.table}::{value._metadata.primary_key}::{value[value._metadata.primary_key]}"
        return None

    @classmethod
    def dereference(cls, value: str, db: TinyDB, recurse: bool = True):
        """
        Resolve a reference string to its record.

        Non-string values (e.g. already-deserialized records) are returned
        unchanged. Raises PointerReferenceError when the target is missing.
        """
        if not value:
            return None
        # isinstance instead of type(...) == str for the idiomatic type check.
        if isinstance(value, str):
            table_name, pkey, pval = value.split("::")
            if pval:
                table = db.table(table_name)
                rec = table.get(where(pkey) == pval, recurse=recurse)
                if not rec:
                    raise PointerReferenceError(f"Expected a {table_name} with {pkey}=={pval} but did not find one!")
                return rec
        return value
@dataclass
class BackReference(Pointer):
    """A Pointer that container fields (Collection, RecordDict) populate with their parent record after insert."""

    pass
@dataclass
class BinaryFilePointer(Field):
    """
    Write the contents of this field to disk and store the path in the db.
    """

    name: str
    extension: str = ".blob"
    value_type = bytes

    def relpath(self, record):
        """Relative on-disk path: TABLE/PKEY_VALUE/FIELD_NAME + extension."""
        return Path(record._metadata.table) / record[record._metadata.primary_key] / f"{self.name}{self.extension}"

    def reference(self, record):
        """Stored form of the path, prefixed with '/::'."""
        return f"/::{self.relpath(record)}"

    def dereference(self, reference, db):
        """Read the referenced file's bytes, or None when the file is missing."""
        relpath = reference.replace("/::", "", 1)
        try:
            return (db.path / relpath).read_bytes()
        except FileNotFoundError:
            return None

    def serialize(self, value: value_type | str, record: Record | None = None) -> str:
        """Store the file reference; the content itself is written by before_insert."""
        return self.reference(record)

    def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
        """Load the file content referenced by value; None when unset."""
        if not value:
            return None
        return self.dereference(value, db)

    def prepare(self, data: value_type):
        """
        Return bytes to be written to disk
        """
        if not data:
            return
        if not isinstance(data, self.value_type):
            # Coerce non-bytes input (e.g. str) into bytes.
            return data.encode()
        return data

    def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
        """Write the field content to disk under the database's root path."""
        if not value:
            return
        relpath = self.relpath(record)
        path = db.path / relpath
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_bytes(self.prepare(value))
@dataclass
class TextFilePointer(BinaryFilePointer):
    """
    Write the contents of this field to disk and store the path in the db.
    """

    name: str
    extension: str = ".txt"
    value_type = str

    def prepare(self, data: value_type):
        """Return the encoded bytes of the (stringified) data; bytes pass through."""
        if isinstance(data, bytes):
            return data
        return str(data).encode()

    def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
        """Load and decode the referenced file; None when unset or missing."""
        if not value:
            return None
        buf = super().deserialize(value, db)
        return buf.decode() if buf else None
@dataclass
class Collection(Field):
    """
    A collection of pointers.
    """

    default: typing.List[value_type] = field(default_factory=lambda: [])
    member_type: type = Record
    value_type = list

    def serialize(self, values: typing.List[value_type], record: Record | None = None) -> typing.List[str]:
        """Render every member as a reference string."""
        return [Pointer.reference(val) for val in values]

    def deserialize(self, values: typing.List[str], db: TinyDB, recurse: bool = False) -> typing.List[value_type]:
        """
        Recursively deserialize the objects in this collection
        """
        recs = []
        if not recurse:
            return values
        for val in values:
            recs.append(Pointer.dereference(val, db=db, recurse=False))
        return recs

    def after_insert(self, db: TinyDB, record: Record) -> None:
        """
        Populate any backreferences in the members of this collection with the parent record's uid.
        """
        if not record[self.name]:
            return
        for member in record[self.name]:
            target = Pointer.dereference(member, db=db, recurse=False)
            for backref in target._metadata.backrefs(type(record)):
                target[backref.name] = record
            db.save(target)
@dataclass
class RecordDict(Field):
    """A mapping of string keys to record pointers."""

    default: typing.Dict[(str, Record)] = field(default_factory=lambda: {})
    member_type: type = Record
    value_type = dict

    def serialize(
        self, values: typing.Dict[(str, value_type)], record: Record | None = None
    ) -> typing.Dict[(str, str)]:
        """Render every value as a reference string."""
        return dict((key, Pointer.reference(val)) for (key, val) in values.items())

    def deserialize(
        self, values: typing.Dict[(str, str)], db: TinyDB, recurse: bool = False
    ) -> typing.Dict[(str, value_type)]:
        """Resolve every value back into a record (shallowly) when recursing."""
        if not recurse:
            return values
        return dict((key, Pointer.dereference(val, db=db, recurse=False)) for (key, val) in values.items())

    def after_insert(self, db: TinyDB, record: Record) -> None:
        """
        Populate any backreferences in the members of this mapping with the parent record's uid.
        """
        if not record[self.name]:
            return
        for key, pointer in record[self.name].items():
            target = Pointer.dereference(pointer, db=db, recurse=False)
            for backref in target._metadata.backrefs(type(record)):
                target[backref.name] = record
            db.save(target)

View File

@ -1,420 +1,9 @@
from __future__ import annotations
import hashlib
import hmac
import os
import re
import typing
from collections import namedtuple
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
import nanoid
from tinydb import TinyDB, where
from grung.exceptions import PointerReferenceError
Metadata = namedtuple("Metadata", ["table", "fields", "backrefs", "primary_key"])
@dataclass
class Field:
"""
Represents a single field in a Record.
"""
name: str
value_type: type = str
default: str = None
unique: bool = False
primary_key: bool = False
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
pass
def after_insert(self, db: TinyDB, record: Record) -> None:
pass
def serialize(self, value: value_type, record: Record | None = None) -> str:
if value is not None:
return str(value)
def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
return value
@dataclass
class Integer(Field):
value_type = int
default: int = 0
def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
return int(value)
@dataclass
class Dict(Field):
value_type: type = str
default: dict = field(default_factory=lambda: {})
def serialize(self, values: dict, record: Record | None = None) -> Dict[(str, str)]:
return dict((key, str(value)) for key, value in values.items())
def deserialize(self, values: dict, db: TinyDB, recurse: bool = False) -> Dict[(str, str)]:
return values
@dataclass
class List(Field):
value_type: type = list
default: list = field(default_factory=lambda: [])
def serialize(self, values: list, record: Record | None = None) -> Dict[(str, str)]:
return values
def deserialize(self, values: list, db: TinyDB, recurse: bool = False) -> typing.List[str]:
return values
@dataclass
class DateTime(Field):
value_type: datetime
default: datetime = datetime.utcfromtimestamp(0)
def serialize(self, value: value_type, record: Record | None = None) -> str:
return (value - datetime.utcfromtimestamp(0)).total_seconds()
def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
return datetime.utcfromtimestamp(int(value))
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
if not value:
record[self.name] = datetime.utcnow().replace(microsecond=0)
@dataclass
class Timestamp(DateTime):
value_type: datetime
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
super().before_insert(None, db, record)
@dataclass
class Password(Field):
value_type = str
default: str = None
# Relatively weak. Consider using stronger initial values in production applications.
salt_size = 4
digest_size = 16
@classmethod
def is_digest(cls, passwd: str):
if not passwd:
return False
offset = 2 * cls.salt_size # each byte is 2 hex chars
try:
if passwd[offset] != ":":
return False
digest = passwd[(offset + 1) :] # noqa
if len(digest) != cls.digest_size * 2:
return False
return re.match(r"^[0-9a-f]+$", digest)
except IndexError:
return False
@classmethod
def get_digest(cls, passwd: str, salt: bytes = None):
if not salt:
salt = os.urandom(cls.salt_size)
digest = hashlib.blake2b(passwd.encode(), digest_size=cls.digest_size, salt=salt).hexdigest()
return digest, salt.hex()
@classmethod
def compare(cls, passwd: value_type, stored: value_type):
stored_salt, stored_digest = stored.split(":")
input_digest, input_salt = cls.get_digest(passwd, bytes.fromhex(stored_salt))
return hmac.compare_digest(input_digest, stored_digest)
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
if value and not self.__class__.is_digest(value):
digest, salt = self.__class__.get_digest(value)
record[self.name] = f"{salt}:{digest}"
class Record(typing.Dict[(str, Field)]):
"""
Base type for a single database record.
"""
def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params):
self.doc_id = doc_id
fields = self.__class__.fields()
pkey = [field for field in fields if field.primary_key]
if len(pkey) > 1:
raise Exception(f"Cannnot have more than one primary key: {pkey}")
elif pkey:
pkey = pkey[0]
else:
# 1% collision rate at ~2M records
pkey = Field("uid", default=nanoid.generate(size=8), primary_key=True)
fields.append(pkey)
pkey.unique = True
self._metadata = Metadata(
table=self.__class__.__name__,
primary_key=pkey.name,
fields={f.name: f for f in fields},
backrefs=lambda value_type: (
field for field in fields if type(field) == BackReference and field.value_type == value_type
),
)
super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params))
@classmethod
def fields(cls):
return []
def serialize(self):
"""
Serialize every field on the record
"""
rec = {}
for name, _field in self._metadata.fields.items():
rec[name] = _field.serialize(self[name], record=self) if isinstance(_field, Field) else _field
return self.__class__(rec, doc_id=self.doc_id)
def deserialize(self, db, recurse: bool = True):
"""
Deserialize every field on the record
"""
rec = {}
for name, _field in self._metadata.fields.items():
rec[name] = _field.deserialize(self[name], db, recurse=recurse)
return self.__class__(rec, doc_id=self.doc_id)
def before_insert(self, db: TinyDB) -> None:
for name, _field in self._metadata.fields.items():
_field.before_insert(self[name], db, self)
def after_insert(self, db: TinyDB) -> None:
for name, _field in self._metadata.fields.items():
_field.after_insert(db, self)
@property
def reference(self):
return Pointer.reference(self)
@property
def path(self):
return Path(self._metadata.table) / self[self._metadata.primary_key]
def __setattr__(self, key, value):
if key in self:
self[key] = value
super().__setattr__(key, value)
def __getattr__(self, attr_name):
if attr_name in self:
return self.get(attr_name)
raise AttributeError(f"No such attribute: {attr_name}")
def __hash__(self):
return hash(str(dict(self)))
def __repr__(self):
return (
f"{self.__class__.__name__}[{self.doc_id}]("
+ ", ".join([f"{key}={val}" for (key, val) in self.items()])
+ ")"
)
@dataclass
class Pointer(Field):
"""
Store a string reference to a record.
"""
name: str = ""
value_type: type = Record
def serialize(self, value: value_type | str, record: Record | None = None) -> str:
return Pointer.reference(value)
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
return Pointer.dereference(value, db, recurse)
@classmethod
def reference(cls, value: Record | str):
# XXX This could be smarter
if isinstance(value, str):
if "::" not in value:
raise PointerReferenceError("Value {value} does not look like a reference!")
return value
if value:
return f"{value._metadata.table}::{value._metadata.primary_key}::{value[value._metadata.primary_key]}"
return None
@classmethod
def dereference(cls, value: str, db: TinyDB, recurse: bool = True):
if not value:
return
elif type(value) == str:
table_name, pkey, pval = value.split("::")
if pval:
table = db.table(table_name)
rec = table.get(where(pkey) == pval, recurse=recurse)
if not rec:
raise PointerReferenceError(f"Expected a {table_name} with {pkey}=={pval} but did not find one!")
return rec
return value
@dataclass
class BackReference(Pointer):
pass
@dataclass
class BinaryFilePointer(Field):
"""
Write the contents of this field to disk and store the path in the db.
"""
name: str
value_type: type = bytes
extension: str = ".blob"
def relpath(self, record):
return Path(record._metadata.table) / record[record._metadata.primary_key] / f"{self.name}{self.extension}"
def reference(self, record):
return f"/::{self.relpath(record)}"
def dereference(self, reference, db):
relpath = reference.replace("/::", "", 1)
try:
return (db.path / relpath).read_bytes()
except FileNotFoundError:
return None
def serialize(self, value: value_type | str, record: Record | None = None) -> str:
return self.reference(record)
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
if not value:
return None
return self.dereference(value, db)
def prepare(self, data: value_type):
"""
Return bytes to be written to disk
"""
if not data:
return
if not isinstance(data, self.value_type):
return data.encode()
return data
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
if not value:
return
relpath = self.relpath(record)
path = db.path / relpath
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(self.prepare(value))
@dataclass
class TextFilePointer(BinaryFilePointer):
"""
Write the contents of this field to disk and store the path in the db.
"""
name: str
value_type: type = str
extension: str = ".txt"
def prepare(self, data: value_type):
if isinstance(data, bytes):
return data
return str(data).encode()
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
if not value:
return None
buf = super().deserialize(value, db)
return buf.decode() if buf else None
@dataclass
class Collection(Field):
"""
A collection of pointers.
"""
value_type: type = Record
default: typing.List[value_type] = field(default_factory=lambda: [])
def serialize(self, values: typing.List[value_type], record: Record | None = None) -> typing.List[str]:
return [Pointer.reference(val) for val in values]
def deserialize(self, values: typing.List[str], db: TinyDB, recurse: bool = False) -> typing.List[value_type]:
"""
Recursively deserialize the objects in this collection
"""
recs = []
if not recurse:
return values
for val in values:
recs.append(Pointer.dereference(val, db=db, recurse=False))
return recs
def after_insert(self, db: TinyDB, record: Record) -> None:
"""
Populate any backreferences in the members of this collection with the parent record's uid.
"""
if not record[self.name]:
return
for member in record[self.name]:
target = Pointer.dereference(member, db=db, recurse=False)
for backref in target._metadata.backrefs(type(record)):
target[backref.name] = record
db.save(target)
@dataclass
class RecordDict(Field):
value_type: type = Record
default: typing.Dict[(str, Record)] = field(default_factory=lambda: {})
def serialize(
self, values: typing.Dict[(str, value_type)], record: Record | None = None
) -> typing.Dict[(str, str)]:
return dict((key, Pointer.reference(val)) for (key, val) in values.items())
def deserialize(
self, values: typing.Dict[(str, str)], db: TinyDB, recurse: bool = False
) -> typing.Dict[(str, value_type)]:
if not recurse:
return values
return dict((key, Pointer.dereference(val, db=db, recurse=False)) for (key, val) in values.items())
def after_insert(self, db: TinyDB, record: Record) -> None:
"""
Populate any backreferences in the members of this mapping with the parent record's uid.
"""
if not record[self.name]:
return
for key, pointer in record[self.name].items():
target = Pointer.dereference(pointer, db=db, recurse=False)
for backref in target._metadata.backrefs(type(record)):
target[backref.name] = record
db.save(target)

199
src/grung/validators.py Normal file
View File

@ -0,0 +1,199 @@
from __future__ import annotations
import re
from dataclasses import dataclass
from tinydb import Query, TinyDB
from grung.exceptions import (
InvalidFieldTypeError,
InvalidLengthError,
InvalidSizeError,
MalformedPointerError,
PatternMatchError,
UniqueConstraintError,
ValidationError,
)
from grung.types import Field, Record
@dataclass
class Validator:
    """
    Base validator. Subclasses override validate(); the base implementation
    always raises, so an unimplemented validator fails loudly.
    """

    def validate(self, record: Record, field: Field, db: TinyDB = None) -> bool:
        raise ValidationError(record, field, self)
@dataclass
class TypeValidator:
    """
    Verify that a field's value matches the field's declared value_type and,
    for container fields, that every member matches member_type.
    """

    def validate_list(self, record: Record, field: Field, db: TinyDB = None) -> bool:
        """Check every member of a list field; raise once with all failures collected."""
        messages = []
        for i, member in enumerate(record[field.name]):
            if isinstance(member, str):
                # Serialized members are string pointers; check the reference format.
                # Do NOT return early here — every member must be checked.
                class_name = field.member_type.__name__
                try:
                    PointerReferenceValidator().validate_string(member, class_name)
                except MalformedPointerError as e:
                    messages.append(str(e))
            elif not isinstance(member, field.member_type):
                messages.append(f"{field.name}[{i}] must be a {field.member_type}, not a {type(member)}.")
        if messages:
            raise InvalidFieldTypeError(record, field, self, messages=messages)
        return True

    def validate_dict(self, record: Record, field: Field, db: TinyDB = None) -> bool:
        """Check every value of a dict field against the field's member_type."""
        for key, member in record[field.name].items():
            if not isinstance(member, field.member_type):
                raise InvalidFieldTypeError(
                    record,
                    field,
                    self,
                    messages=[f"{field.name}[{key}] must be a {field.member_type}, not a {type(member)}."],
                )
        return True

    def validate(self, record: Record, field: Field, db: TinyDB = None) -> bool:
        """Validate the field's value_type, then recurse into container members."""
        if record[field.name] is None:
            # Unset values are always acceptable at the type level.
            return True
        if not isinstance(record[field.name], field.value_type):
            raise InvalidFieldTypeError(
                record,
                field,
                self,
                messages=[f"{field.name} must be a {field.value_type}, not a {type(record[field.name])}."],
            )
        if not hasattr(field, "member_type"):
            return True
        if field.value_type == dict:
            self.validate_dict(record, field, db)
        elif field.value_type == list:
            self.validate_list(record, field, db)
        else:
            raise RuntimeError("Expected a validation for iterable but didn't get one!")
        return True
@dataclass
class PointerReferenceValidator(Validator):
    """
    Verify that the Pointer is either a string reference to the correct member type,
    or a record instance of member_type that hasn't been serialized.
    """

    def validate_string(self, value: str, type_name: str = "") -> bool:
        """
        Validate a serialized reference of the form 'TABLE::PKEY_NAME::PKEY_VALUE'.

        Raises MalformedPointerError when the table or primary-key segment is
        wrong; a string that does not split into three parts raises ValueError.
        """
        (table, primary_key, val) = value.split("::")
        if type_name and table != type_name:
            raise MalformedPointerError(
                {"string": value},
                "",
                self,
                messages=[f"field should reference '{type_name}', not '{table}'."],
            )
        if not primary_key:
            raise MalformedPointerError(
                {"string": value},
                "",
                self,
                messages=["Pointers must specify the primary_key field name."],
            )

    def validate(self, record: Record | str, field: Field, db: TinyDB = None) -> bool:
        """Validate either a bare reference string or a Record holding a pointer field."""
        # A bare string must be handled before any mapping-style access:
        # indexing a str with a field name raises TypeError.
        if isinstance(record, str):
            try:
                self.validate_string(record, field.value_type.__name__)
            except ValueError:
                raise MalformedPointerError({field.name: record}, field, self)
            return True
        if record[field.name] is None:
            return True
        if not isinstance(record, Record):
            raise MalformedPointerError(record, field, self)
        if not isinstance(record[field.name], field.value_type):
            raise MalformedPointerError(record, field, self)
        return True
@dataclass
class UniqueValidator(Validator):
    """Enforce a field's unique constraint across its table."""

    def validate(self, record: Record, field: Field, db: TinyDB) -> bool:
        """
        Returns true if the field's value is unique across all records in the table.
        """
        if record[field.name] is None:
            return True
        # Escape the value so regex metacharacters in the data cannot break
        # (or subvert) the case-insensitive equality match.
        pattern = f"^{re.escape(str(record[field.name]))}$"
        query = Query()[field.name].matches(pattern, flags=re.IGNORECASE)
        table = db.table(record._metadata.table)
        matches = [dict(match) for match in table.search(query) if match.doc_id != record.doc_id]
        if matches:
            raise UniqueConstraintError(record, field, self, query=query, matches=matches)
        return True
@dataclass
class LengthValidator(Validator):
    """Constrain the length of a field's value to the closed interval [min, max]."""

    min: int = 0
    max: int = 0

    def validate(self, record: Record, field: Field, db: TinyDB = None) -> bool:
        """
        Return True when the field is unset or its length falls within
        [min, max]; raise InvalidLengthError otherwise.
        """
        value = record[field.name]
        if value is None:
            return True
        if self.min <= len(value) <= self.max:
            return True
        raise InvalidLengthError(
            record,
            field,
            self,
            messages=[f"The field length must be between {self.min} and {self.max}, inclusive."],
        )
@dataclass
class MinMaxValidator(Validator):
    """Constrain a field's integer value to the closed interval [min, max]."""

    min: int = 0
    max: int = 0

    def validate(self, record: Record, field: Field, db: TinyDB = None) -> bool:
        """
        Return True when the field is unset or its integer value falls within
        [min, max]; raise InvalidSizeError otherwise.
        """
        value = record[field.name]
        if value is None:
            return True
        if self.min <= int(value) <= self.max:
            return True
        raise InvalidSizeError(
            record,
            field,
            self,
            messages=[f"The field size must be between {self.min} and {self.max}, inclusive."],
        )
@dataclass
class PatternValidator(Validator):
    """Require a field's value to match a compiled regular expression."""

    pattern: re.Pattern

    def validate(self, record: Record, field: Field, db: TinyDB = None) -> bool:
        """Return True when the field is unset or matches the pattern; raise PatternMatchError otherwise."""
        value = record[field.name]
        if value is None or self.pattern.match(value):
            return True
        raise PatternMatchError(
            record,
            field,
            self,
            messages=[f"The field value must match the pattern {self.pattern}"],
        )

View File

@ -10,7 +10,14 @@ from tinydb.storages import MemoryStorage
from grung import examples
from grung.db import GrungDB
from grung.exceptions import CircularReferenceError, UniqueConstraintError
from grung.exceptions import (
CircularReferenceError,
InvalidFieldTypeError,
InvalidLengthError,
InvalidSizeError,
PatternMatchError,
UniqueConstraintError,
)
@pytest.fixture
@ -82,7 +89,7 @@ def test_subgroups(db):
# recursion!
with pytest.raises(CircularReferenceError):
tos.members = [tos]
tos.groups = [tos]
db.save(tos)
@ -201,3 +208,29 @@ def test_file_pointers(db):
location_on_disk = db.path / album._metadata.fields["review"].relpath(album)
assert location_on_disk.read_text() == album.review
@pytest.mark.parametrize(
    "updates, expected",
    [
        ({"name": ""}, InvalidLengthError),
        ({"name": "a name longer than 30 characters is what we have here"}, InvalidLengthError),
        ({"name": 23}, InvalidFieldTypeError),
        ({"number": -1}, InvalidSizeError),
        ({"number": 256}, InvalidSizeError),
        ({"email": "foo+alias@"}, PatternMatchError),
    ],
    ids=[
        "name too short",
        "name too long",
        "name is not a string",
        "number too small",
        "number too big",
        "invalid email addres",
    ],
)
def test_validators(updates, expected, db):
    """Each invalid field update should raise the corresponding validator error on save."""
    user = db.save(examples.User(name="john", email="john@foo", password="fnord", created=datetime.utcnow()))
    with pytest.raises(expected):
        user.update(**updates)
        db.save(user)