Implement FilePointer

This commit is contained in:
evilchili 2025-10-18 14:23:16 -07:00
parent 7e05915540
commit 8bc3f07b28
4 changed files with 69 additions and 24 deletions

View File

@ -2,9 +2,11 @@ import inspect
import re import re
from functools import reduce from functools import reduce
from operator import ior from operator import ior
from pathlib import Path
from typing import List from typing import List
from tinydb import Query, TinyDB, table from tinydb import Query, TinyDB, table
from tinydb.storages import MemoryStorage
from tinydb.table import Document from tinydb.table import Document
from grung.exceptions import UniqueConstraintError from grung.exceptions import UniqueConstraintError
@ -18,11 +20,11 @@ class RecordTable(table.Table):
def __init__(self, name: str, db: TinyDB, document_class: Document = Record, **kwargs): def __init__(self, name: str, db: TinyDB, document_class: Document = Record, **kwargs):
self.document_class = document_class self.document_class = document_class
self._db = db self.db = db
super().__init__(db.storage, name, **kwargs) super().__init__(db.storage, name, **kwargs)
def insert(self, document): def insert(self, document):
document.before_insert(self._db) document.before_insert(self.db)
doc = document.serialize() doc = document.serialize()
self._check_constraints(doc) self._check_constraints(doc)
@ -31,8 +33,8 @@ class RecordTable(table.Table):
else: else:
last_insert_id = super().insert(dict(doc)) last_insert_id = super().insert(dict(doc))
doc.doc_id = last_insert_id doc.doc_id = last_insert_id
doc.after_insert(self._db) doc.after_insert(self.db)
return doc.deserialize(self._db) return doc.deserialize(self.db)
def get(self, *args, doc_id: int = None, recurse: bool = False, **kwargs): def get(self, *args, doc_id: int = None, recurse: bool = False, **kwargs):
""" """
@ -48,7 +50,7 @@ class RecordTable(table.Table):
if doc_id: if doc_id:
document = super().get(doc_id=doc_id) document = super().get(doc_id=doc_id)
if document: if document:
return document.deserialize(self._db, recurse=recurse) return document.deserialize(self.db, recurse=recurse)
matches = self.search(*args, recurse=recurse, **kwargs) matches = self.search(*args, recurse=recurse, **kwargs)
if matches: if matches:
@ -56,7 +58,7 @@ class RecordTable(table.Table):
def search(self, *args, recurse: bool = False, **kwargs) -> List[Record]: def search(self, *args, recurse: bool = False, **kwargs) -> List[Record]:
results = super().search(*args, **kwargs) results = super().search(*args, **kwargs)
return [r.deserialize(self._db, recurse=recurse) for r in results] return [r.deserialize(self.db, recurse=recurse) for r in results]
def remove(self, document): def remove(self, document):
if document.doc_id: if document.doc_id:
@ -89,7 +91,10 @@ class GrungDB(TinyDB):
default_table_name = "Record" default_table_name = "Record"
_tables = {} _tables = {}
def __init__(self, *args, **kwargs): def __init__(self, path: Path, *args, **kwargs):
self.path = path
if kwargs["storage"] != MemoryStorage:
args = (path,) + args
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.create_table(Record) self.create_table(Record)
@ -122,8 +127,8 @@ class GrungDB(TinyDB):
return super().__getattr__(attr_name) return super().__getattr__(attr_name)
@classmethod @classmethod
def with_schema(cls, schema_module, *args, **kwargs): def with_schema(cls, schema_module, path: Path | None, *args, **kwargs):
db = GrungDB(*args, **kwargs) db = GrungDB(path=path, *args, **kwargs)
for name, obj in inspect.getmembers(schema_module): for name, obj in inspect.getmembers(schema_module):
if type(obj) == type and issubclass(obj, Record): if type(obj) == type and issubclass(obj, Record):
db.create_table(obj) db.create_table(obj)

View File

@ -4,6 +4,7 @@ from grung.types import (
DateTime, DateTime,
Dict, Dict,
Field, Field,
FilePointer,
Integer, Integer,
List, List,
Password, Password,
@ -49,6 +50,7 @@ class Album(Record):
Dict("credits"), Dict("credits"),
List("tracks"), List("tracks"),
BackReference("artist", Artist), BackReference("artist", Artist),
FilePointer("cover", extension=".jpg"),
] ]
@ -56,7 +58,4 @@ class Artist(User):
@classmethod @classmethod
def fields(cls): def fields(cls):
inherited = [f for f in super().fields() if f.name != "name"] inherited = [f for f in super().fields() if f.name != "name"]
return inherited + [ return inherited + [Field("name"), RecordDict("albums", Album)]
Field("name"),
RecordDict("albums", Album),
]

View File

@ -8,6 +8,7 @@ import typing
from collections import namedtuple from collections import namedtuple
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from pathlib import Path
import nanoid import nanoid
from tinydb import TinyDB, where from tinydb import TinyDB, where
@ -35,7 +36,7 @@ class Field:
def after_insert(self, db: TinyDB, record: Record) -> None: def after_insert(self, db: TinyDB, record: Record) -> None:
pass pass
def serialize(self, value: value_type) -> str: def serialize(self, value: value_type, record: Record | None = None) -> str:
if value is not None: if value is not None:
return str(value) return str(value)
@ -57,7 +58,7 @@ class Dict(Field):
value_type: type = str value_type: type = str
default: dict = field(default_factory=lambda: {}) default: dict = field(default_factory=lambda: {})
def serialize(self, values: dict) -> Dict[(str, str)]: def serialize(self, values: dict, record: Record | None = None) -> Dict[(str, str)]:
return dict((key, str(value)) for key, value in values.items()) return dict((key, str(value)) for key, value in values.items())
def deserialize(self, values: dict, db: TinyDB, recurse: bool = False) -> Dict[(str, str)]: def deserialize(self, values: dict, db: TinyDB, recurse: bool = False) -> Dict[(str, str)]:
@ -69,7 +70,7 @@ class List(Field):
value_type: type = list value_type: type = list
default: list = field(default_factory=lambda: []) default: list = field(default_factory=lambda: [])
def serialize(self, values: list) -> Dict[(str, str)]: def serialize(self, values: list, record: Record | None = None) -> Dict[(str, str)]:
return values return values
def deserialize(self, values: list, db: TinyDB, recurse: bool = False) -> typing.List[str]: def deserialize(self, values: list, db: TinyDB, recurse: bool = False) -> typing.List[str]:
@ -81,7 +82,7 @@ class DateTime(Field):
value_type: datetime value_type: datetime
default: datetime = datetime.utcfromtimestamp(0) default: datetime = datetime.utcfromtimestamp(0)
def serialize(self, value: value_type) -> str: def serialize(self, value: value_type, record: Record | None = None) -> str:
return (value - datetime.utcfromtimestamp(0)).total_seconds() return (value - datetime.utcfromtimestamp(0)).total_seconds()
def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type: def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
@ -184,7 +185,7 @@ class Record(typing.Dict[(str, Field)]):
""" """
rec = {} rec = {}
for name, _field in self._metadata.fields.items(): for name, _field in self._metadata.fields.items():
rec[name] = _field.serialize(self[name]) if isinstance(_field, Field) else _field rec[name] = _field.serialize(self[name], record=self) if isinstance(_field, Field) else _field
return self.__class__(rec, doc_id=self.doc_id) return self.__class__(rec, doc_id=self.doc_id)
def deserialize(self, db, recurse: bool = True): def deserialize(self, db, recurse: bool = True):
@ -208,6 +209,10 @@ class Record(typing.Dict[(str, Field)]):
def reference(self): def reference(self):
return Pointer.reference(self) return Pointer.reference(self)
@property
def path(self):
return Path(self._metadata.table) / self[self._metadata.primary_key]
def __setattr__(self, key, value): def __setattr__(self, key, value):
if key in self: if key in self:
self[key] = value self[key] = value
@ -238,7 +243,7 @@ class Pointer(Field):
name: str = "" name: str = ""
value_type: type = Record value_type: type = Record
def serialize(self, value: value_type | str) -> str: def serialize(self, value: value_type | str, record: Record | None = None) -> str:
return Pointer.reference(value) return Pointer.reference(value)
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type: def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
@ -277,6 +282,32 @@ class BackReference(Pointer):
pass pass
@dataclass
class FilePointer(Field):
"""
Write the contents of this field to disk and store the path in the db.
"""
name: str
value_type: type = bytes
extension: str = ".txt"
def relpath(self, record):
return Path(record._metadata.table) / record[record._metadata.primary_key] / f"{self.name}{self.extension}"
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
if not value:
return None
return (db.path / value).read_bytes()
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
relpath = self.relpath(record)
path = db.path / relpath
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(record[self.name])
record[self.name] = relpath
@dataclass @dataclass
class Collection(Field): class Collection(Field):
""" """
@ -286,7 +317,7 @@ class Collection(Field):
value_type: type = Record value_type: type = Record
default: typing.List[value_type] = field(default_factory=lambda: []) default: typing.List[value_type] = field(default_factory=lambda: [])
def serialize(self, values: typing.List[value_type]) -> typing.List[str]: def serialize(self, values: typing.List[value_type], record: Record | None = None) -> typing.List[str]:
return [Pointer.reference(val) for val in values] return [Pointer.reference(val) for val in values]
def deserialize(self, values: typing.List[str], db: TinyDB, recurse: bool = False) -> typing.List[value_type]: def deserialize(self, values: typing.List[str], db: TinyDB, recurse: bool = False) -> typing.List[value_type]:
@ -319,7 +350,9 @@ class RecordDict(Field):
value_type: type = Record value_type: type = Record
default: typing.Dict[(str, Record)] = field(default_factory=lambda: {}) default: typing.Dict[(str, Record)] = field(default_factory=lambda: {})
def serialize(self, values: typing.Dict[(str, value_type)]) -> typing.Dict[(str, str)]: def serialize(
self, values: typing.Dict[(str, value_type)], record: Record | None = None
) -> typing.Dict[(str, str)]:
return dict((key, Pointer.reference(val)) for (key, val) in values.items()) return dict((key, Pointer.reference(val)) for (key, val) in values.items())
def deserialize( def deserialize(

View File

@ -1,4 +1,6 @@
import tempfile
from datetime import datetime from datetime import datetime
from pathlib import Path
from pprint import pprint as print from pprint import pprint as print
from time import sleep from time import sleep
@ -13,7 +15,8 @@ from grung.exceptions import PointerReferenceError, UniqueConstraintError
@pytest.fixture @pytest.fixture
def db(): def db():
_db = GrungDB.with_schema(examples, storage=MemoryStorage) with tempfile.TemporaryDirectory() as path:
_db = GrungDB.with_schema(examples, path=Path(path), storage=MemoryStorage)
yield _db yield _db
print(_db) print(_db)
@ -153,10 +156,12 @@ def test_mapping(db):
name="The Impossible Kid", name="The Impossible Kid",
credits={"Produced By": "Aesop Rock", "Lyrics By": "Aesop Rock", "Puke in the MeowMix By": "Kirby"}, credits={"Produced By": "Aesop Rock", "Lyrics By": "Aesop Rock", "Puke in the MeowMix By": "Kirby"},
tracks=["Mystery Fish", "Rings", "Lotta Years", "Dorks"], tracks=["Mystery Fish", "Rings", "Lotta Years", "Dorks"],
cover=b"some jpg data",
) )
) )
assert album.credits["Produced By"] == "Aesop Rock" assert album.credits["Produced By"] == "Aesop Rock"
assert album.tracks[0] == "Mystery Fish" assert album.tracks[0] == "Mystery Fish"
assert album.cover == b"some jpg data"
aes = db.save( aes = db.save(
examples.Artist( examples.Artist(
@ -170,3 +175,6 @@ def test_mapping(db):
assert album.name in aes.albums assert album.name in aes.albums
assert aes.albums[album.name].uid == album.uid assert aes.albums[album.name].uid == album.uid
assert "Kirby" in aes.albums[album.name].credits.values() assert "Kirby" in aes.albums[album.name].credits.values()
location_on_disk = db.path / album._metadata.fields["cover"].relpath(album)
assert location_on_disk.read_bytes() == album.cover