add constraint check for recursion

This commit is contained in:
evilchili 2025-10-30 19:56:56 -07:00
parent cdce235ecd
commit 5c617e7c5c
4 changed files with 32 additions and 12 deletions

View File

@ -1,5 +1,6 @@
import inspect import inspect
import re import re
from collections.abc import Iterable
from functools import reduce from functools import reduce
from operator import ior from operator import ior
from pathlib import Path from pathlib import Path
@ -9,7 +10,7 @@ from tinydb import Query, TinyDB, table
from tinydb.storages import MemoryStorage from tinydb.storages import MemoryStorage
from tinydb.table import Document from tinydb.table import Document
from grung.exceptions import UniqueConstraintError from grung.exceptions import CircularReferenceError, UniqueConstraintError
from grung.types import Record from grung.types import Record
@ -66,8 +67,15 @@ class RecordTable(table.Table):
super().remove(doc_ids=[document.doc_id]) super().remove(doc_ids=[document.doc_id])
def _check_constraints(self, document) -> bool: def _check_constraints(self, document) -> bool:
self._check_for_recursion(document)
self._check_unique(document) self._check_unique(document)
def _check_for_recursion(self, document) -> bool:
ref = document.reference
for field in document._metadata.fields.values():
if isinstance(field.default, Iterable) and ref in document[field.name]:
raise CircularReferenceError(ref, field)
def _check_unique(self, document) -> bool: def _check_unique(self, document) -> bool:
matches = [] matches = []
queries = reduce( queries = reduce(

View File

@ -25,3 +25,18 @@ class PointerReferenceError(Exception):
f" * Error: Invalid Pointer\n" f" * Error: Invalid Pointer\n"
" * This collection member does not refer an existing record. Do you need to save it first?" " * This collection member does not refer an existing record. Do you need to save it first?"
) )
class CircularReferenceError(Exception):
"""
Thrown when a record contains a reference to itself.
"""
def __init__(self, reference, field):
super().__init__(
"\n"
f" * Reference: {reference}\n"
f" * Field: {field.name}\n"
f" * Error: Circular Reference\n"
f" * This record contains a reference to itself. This will lead to infinite recursion."
)

View File

@ -257,8 +257,6 @@ class Pointer(Field):
raise PointerReferenceError("Value {value} does not look like a reference!") raise PointerReferenceError("Value {value} does not look like a reference!")
return value return value
if value: if value:
if not value.doc_id:
raise PointerReferenceError(value)
return f"{value._metadata.table}::{value._metadata.primary_key}::{value[value._metadata.primary_key]}" return f"{value._metadata.table}::{value._metadata.primary_key}::{value[value._metadata.primary_key]}"
return None return None

View File

@ -10,7 +10,7 @@ from tinydb.storages import MemoryStorage
from grung import examples from grung import examples
from grung.db import GrungDB from grung.db import GrungDB
from grung.exceptions import PointerReferenceError, UniqueConstraintError from grung.exceptions import CircularReferenceError, UniqueConstraintError
@pytest.fixture @pytest.fixture
@ -49,13 +49,7 @@ def test_crud(db):
def test_pointers(db): def test_pointers(db):
user = examples.User(name="john", email="john@foo") user = db.save(examples.User(name="john", email="john@foo"))
players = examples.Group(name="players", members=[user])
with pytest.raises(PointerReferenceError):
players = db.save(players)
user = db.save(user)
players = db.save(examples.Group(name="players", members=[user])) players = db.save(examples.Group(name="players", members=[user]))
user = db.table("User").get(doc_id=user.doc_id) user = db.table("User").get(doc_id=user.doc_id)
@ -86,6 +80,11 @@ def test_subgroups(db):
kirk = db.table("User").get(doc_id=kirk.doc_id) kirk = db.table("User").get(doc_id=kirk.doc_id)
assert kirk.reference in unique_users assert kirk.reference in unique_users
# recursion!
with pytest.raises(CircularReferenceError):
tos.members = [tos]
db.save(tos)
def test_unique(db): def test_unique(db):
user1 = examples.User(name="john", email="john@foo") user1 = examples.User(name="john", email="john@foo")