Implement FilePointer

This commit is contained in:
evilchili 2025-10-18 14:23:16 -07:00
parent 7e05915540
commit 8bc3f07b28
4 changed files with 69 additions and 24 deletions

View File

@ -2,9 +2,11 @@ import inspect
import re
from functools import reduce
from operator import ior
from pathlib import Path
from typing import List
from tinydb import Query, TinyDB, table
from tinydb.storages import MemoryStorage
from tinydb.table import Document
from grung.exceptions import UniqueConstraintError
@ -18,11 +20,11 @@ class RecordTable(table.Table):
def __init__(self, name: str, db: TinyDB, document_class: Document = Record, **kwargs):
self.document_class = document_class
self._db = db
self.db = db
super().__init__(db.storage, name, **kwargs)
def insert(self, document):
document.before_insert(self._db)
document.before_insert(self.db)
doc = document.serialize()
self._check_constraints(doc)
@ -31,8 +33,8 @@ class RecordTable(table.Table):
else:
last_insert_id = super().insert(dict(doc))
doc.doc_id = last_insert_id
doc.after_insert(self._db)
return doc.deserialize(self._db)
doc.after_insert(self.db)
return doc.deserialize(self.db)
def get(self, *args, doc_id: int = None, recurse: bool = False, **kwargs):
"""
@ -48,7 +50,7 @@ class RecordTable(table.Table):
if doc_id:
document = super().get(doc_id=doc_id)
if document:
return document.deserialize(self._db, recurse=recurse)
return document.deserialize(self.db, recurse=recurse)
matches = self.search(*args, recurse=recurse, **kwargs)
if matches:
@ -56,7 +58,7 @@ class RecordTable(table.Table):
def search(self, *args, recurse: bool = False, **kwargs) -> List[Record]:
results = super().search(*args, **kwargs)
return [r.deserialize(self._db, recurse=recurse) for r in results]
return [r.deserialize(self.db, recurse=recurse) for r in results]
def remove(self, document):
if document.doc_id:
@ -89,7 +91,10 @@ class GrungDB(TinyDB):
default_table_name = "Record"
_tables = {}
def __init__(self, *args, **kwargs):
def __init__(self, path: Path, *args, **kwargs):
self.path = path
if kwargs["storage"] != MemoryStorage:
args = (path,) + args
super().__init__(*args, **kwargs)
self.create_table(Record)
@ -122,8 +127,8 @@ class GrungDB(TinyDB):
return super().__getattr__(attr_name)
@classmethod
def with_schema(cls, schema_module, *args, **kwargs):
db = GrungDB(*args, **kwargs)
def with_schema(cls, schema_module, path: Path | None, *args, **kwargs):
db = GrungDB(path=path, *args, **kwargs)
for name, obj in inspect.getmembers(schema_module):
if type(obj) == type and issubclass(obj, Record):
db.create_table(obj)

View File

@ -4,6 +4,7 @@ from grung.types import (
DateTime,
Dict,
Field,
FilePointer,
Integer,
List,
Password,
@ -49,6 +50,7 @@ class Album(Record):
Dict("credits"),
List("tracks"),
BackReference("artist", Artist),
FilePointer("cover", extension=".jpg"),
]
@ -56,7 +58,4 @@ class Artist(User):
@classmethod
def fields(cls):
inherited = [f for f in super().fields() if f.name != "name"]
return inherited + [
Field("name"),
RecordDict("albums", Album),
]
return inherited + [Field("name"), RecordDict("albums", Album)]

View File

@ -8,6 +8,7 @@ import typing
from collections import namedtuple
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
import nanoid
from tinydb import TinyDB, where
@ -35,7 +36,7 @@ class Field:
def after_insert(self, db: TinyDB, record: Record) -> None:
pass
def serialize(self, value: value_type) -> str:
def serialize(self, value: value_type, record: Record | None = None) -> str:
if value is not None:
return str(value)
@ -57,7 +58,7 @@ class Dict(Field):
value_type: type = str
default: dict = field(default_factory=lambda: {})
def serialize(self, values: dict) -> Dict[(str, str)]:
def serialize(self, values: dict, record: Record | None = None) -> Dict[(str, str)]:
return dict((key, str(value)) for key, value in values.items())
def deserialize(self, values: dict, db: TinyDB, recurse: bool = False) -> Dict[(str, str)]:
@ -69,7 +70,7 @@ class List(Field):
value_type: type = list
default: list = field(default_factory=lambda: [])
def serialize(self, values: list) -> Dict[(str, str)]:
def serialize(self, values: list, record: Record | None = None) -> Dict[(str, str)]:
return values
def deserialize(self, values: list, db: TinyDB, recurse: bool = False) -> typing.List[str]:
@ -81,7 +82,7 @@ class DateTime(Field):
value_type: datetime
default: datetime = datetime.utcfromtimestamp(0)
def serialize(self, value: value_type) -> str:
def serialize(self, value: value_type, record: Record | None = None) -> str:
return (value - datetime.utcfromtimestamp(0)).total_seconds()
def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
@ -184,7 +185,7 @@ class Record(typing.Dict[(str, Field)]):
"""
rec = {}
for name, _field in self._metadata.fields.items():
rec[name] = _field.serialize(self[name]) if isinstance(_field, Field) else _field
rec[name] = _field.serialize(self[name], record=self) if isinstance(_field, Field) else _field
return self.__class__(rec, doc_id=self.doc_id)
def deserialize(self, db, recurse: bool = True):
@ -208,6 +209,10 @@ class Record(typing.Dict[(str, Field)]):
def reference(self):
return Pointer.reference(self)
@property
def path(self):
return Path(self._metadata.table) / self[self._metadata.primary_key]
def __setattr__(self, key, value):
if key in self:
self[key] = value
@ -238,7 +243,7 @@ class Pointer(Field):
name: str = ""
value_type: type = Record
def serialize(self, value: value_type | str) -> str:
def serialize(self, value: value_type | str, record: Record | None = None) -> str:
return Pointer.reference(value)
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
@ -277,6 +282,32 @@ class BackReference(Pointer):
pass
@dataclass
class FilePointer(Field):
"""
Write the contents of this field to disk and store the path in the db.
"""
name: str
value_type: type = bytes
extension: str = ".txt"
def relpath(self, record):
return Path(record._metadata.table) / record[record._metadata.primary_key] / f"{self.name}{self.extension}"
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
if not value:
return None
return (db.path / value).read_bytes()
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
relpath = self.relpath(record)
path = db.path / relpath
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(record[self.name])
record[self.name] = relpath
@dataclass
class Collection(Field):
"""
@ -286,7 +317,7 @@ class Collection(Field):
value_type: type = Record
default: typing.List[value_type] = field(default_factory=lambda: [])
def serialize(self, values: typing.List[value_type]) -> typing.List[str]:
def serialize(self, values: typing.List[value_type], record: Record | None = None) -> typing.List[str]:
return [Pointer.reference(val) for val in values]
def deserialize(self, values: typing.List[str], db: TinyDB, recurse: bool = False) -> typing.List[value_type]:
@ -319,7 +350,9 @@ class RecordDict(Field):
value_type: type = Record
default: typing.Dict[(str, Record)] = field(default_factory=lambda: {})
def serialize(self, values: typing.Dict[(str, value_type)]) -> typing.Dict[(str, str)]:
def serialize(
self, values: typing.Dict[(str, value_type)], record: Record | None = None
) -> typing.Dict[(str, str)]:
return dict((key, Pointer.reference(val)) for (key, val) in values.items())
def deserialize(

View File

@ -1,4 +1,6 @@
import tempfile
from datetime import datetime
from pathlib import Path
from pprint import pprint as print
from time import sleep
@ -13,9 +15,10 @@ from grung.exceptions import PointerReferenceError, UniqueConstraintError
@pytest.fixture
def db():
_db = GrungDB.with_schema(examples, storage=MemoryStorage)
yield _db
print(_db)
with tempfile.TemporaryDirectory() as path:
_db = GrungDB.with_schema(examples, path=Path(path), storage=MemoryStorage)
yield _db
print(_db)
def test_crud(db):
@ -153,10 +156,12 @@ def test_mapping(db):
name="The Impossible Kid",
credits={"Produced By": "Aesop Rock", "Lyrics By": "Aesop Rock", "Puke in the MeowMix By": "Kirby"},
tracks=["Mystery Fish", "Rings", "Lotta Years", "Dorks"],
cover=b"some jpg data",
)
)
assert album.credits["Produced By"] == "Aesop Rock"
assert album.tracks[0] == "Mystery Fish"
assert album.cover == b"some jpg data"
aes = db.save(
examples.Artist(
@ -170,3 +175,6 @@ def test_mapping(db):
assert album.name in aes.albums
assert aes.albums[album.name].uid == album.uid
assert "Kirby" in aes.albums[album.name].credits.values()
location_on_disk = db.path / album._metadata.fields["cover"].relpath(album)
assert location_on_disk.read_bytes() == album.cover