diff --git a/src/grung/db.py b/src/grung/db.py
index 452433f..62f8227 100644
--- a/src/grung/db.py
+++ b/src/grung/db.py
@@ -2,9 +2,11 @@ import inspect
 import re
 from functools import reduce
 from operator import ior
+from pathlib import Path
 from typing import List
 
 from tinydb import Query, TinyDB, table
+from tinydb.storages import MemoryStorage
 from tinydb.table import Document
 
 from grung.exceptions import UniqueConstraintError
@@ -18,11 +20,11 @@ class RecordTable(table.Table):
 
     def __init__(self, name: str, db: TinyDB, document_class: Document = Record, **kwargs):
         self.document_class = document_class
-        self._db = db
+        self.db = db
         super().__init__(db.storage, name, **kwargs)
 
     def insert(self, document):
-        document.before_insert(self._db)
+        document.before_insert(self.db)
         doc = document.serialize()
         self._check_constraints(doc)
 
@@ -31,8 +33,8 @@ class RecordTable(table.Table):
         else:
             last_insert_id = super().insert(dict(doc))
         doc.doc_id = last_insert_id
-        doc.after_insert(self._db)
-        return doc.deserialize(self._db)
+        doc.after_insert(self.db)
+        return doc.deserialize(self.db)
 
     def get(self, *args, doc_id: int = None, recurse: bool = False, **kwargs):
         """
@@ -48,7 +50,7 @@ class RecordTable(table.Table):
         if doc_id:
             document = super().get(doc_id=doc_id)
             if document:
-                return document.deserialize(self._db, recurse=recurse)
+                return document.deserialize(self.db, recurse=recurse)
 
         matches = self.search(*args, recurse=recurse, **kwargs)
         if matches:
@@ -56,7 +58,7 @@ class RecordTable(table.Table):
 
     def search(self, *args, recurse: bool = False, **kwargs) -> List[Record]:
         results = super().search(*args, **kwargs)
-        return [r.deserialize(self._db, recurse=recurse) for r in results]
+        return [r.deserialize(self.db, recurse=recurse) for r in results]
 
     def remove(self, document):
         if document.doc_id:
@@ -89,7 +91,10 @@ class GrungDB(TinyDB):
     default_table_name = "Record"
     _tables = {}
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, path: Path, *args, **kwargs):
+        self.path = path
+        if kwargs.get("storage") is not MemoryStorage:  # "storage" may be omitted; avoid KeyError
+            args = (path,) + args
         super().__init__(*args, **kwargs)
         self.create_table(Record)
 
@@ -122,8 +127,8 @@ class GrungDB(TinyDB):
         return super().__getattr__(attr_name)
 
     @classmethod
-    def with_schema(cls, schema_module, *args, **kwargs):
-        db = GrungDB(*args, **kwargs)
+    def with_schema(cls, schema_module, path: Path | None, *args, **kwargs):
+        db = GrungDB(path, *args, **kwargs)  # positional: path=... plus *args would bind path twice
         for name, obj in inspect.getmembers(schema_module):
             if type(obj) == type and issubclass(obj, Record):
                 db.create_table(obj)
diff --git a/src/grung/examples.py b/src/grung/examples.py
index 904e392..a833e9d 100644
--- a/src/grung/examples.py
+++ b/src/grung/examples.py
@@ -4,6 +4,7 @@ from grung.types import (
     DateTime,
     Dict,
     Field,
+    FilePointer,
     Integer,
     List,
     Password,
@@ -49,6 +50,7 @@ class Album(Record):
             Dict("credits"),
             List("tracks"),
             BackReference("artist", Artist),
+            FilePointer("cover", extension=".jpg"),
         ]
 
 
@@ -56,7 +58,4 @@ class Artist(User):
     @classmethod
     def fields(cls):
         inherited = [f for f in super().fields() if f.name != "name"]
-        return inherited + [
-            Field("name"),
-            RecordDict("albums", Album),
-        ]
+        return inherited + [Field("name"), RecordDict("albums", Album)]
diff --git a/src/grung/types.py b/src/grung/types.py
index 0af2fcd..285787e 100644
--- a/src/grung/types.py
+++ b/src/grung/types.py
@@ -8,6 +8,7 @@ import typing
 from collections import namedtuple
 from dataclasses import dataclass, field
 from datetime import datetime
+from pathlib import Path
 
 import nanoid
 from tinydb import TinyDB, where
@@ -35,7 +36,7 @@ class Field:
     def after_insert(self, db: TinyDB, record: Record) -> None:
         pass
 
-    def serialize(self, value: value_type) -> str:
+    def serialize(self, value: value_type, record: Record | None = None) -> str:
         if value is not None:
             return str(value)
 
@@ -57,7 +58,7 @@ class Dict(Field):
     value_type: type = str
     default: dict = field(default_factory=lambda: {})
 
-    def serialize(self, values: dict) -> Dict[(str, str)]:
+    def serialize(self, values: dict, record: Record | None = None) -> Dict[(str, str)]:
         return dict((key, str(value)) for key, value in values.items())
 
     def deserialize(self, values: dict, db: TinyDB, recurse: bool = False) -> Dict[(str, str)]:
@@ -69,7 +70,7 @@ class List(Field):
     value_type: type = list
     default: list = field(default_factory=lambda: [])
 
-    def serialize(self, values: list) -> Dict[(str, str)]:
+    def serialize(self, values: list, record: Record | None = None) -> Dict[(str, str)]:
         return values
 
     def deserialize(self, values: list, db: TinyDB, recurse: bool = False) -> typing.List[str]:
@@ -81,7 +82,7 @@ class DateTime(Field):
     value_type: datetime
     default: datetime = datetime.utcfromtimestamp(0)
 
-    def serialize(self, value: value_type) -> str:
+    def serialize(self, value: value_type, record: Record | None = None) -> str:
         return (value - datetime.utcfromtimestamp(0)).total_seconds()
 
     def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
@@ -184,7 +185,7 @@ class Record(typing.Dict[(str, Field)]):
         """
         rec = {}
         for name, _field in self._metadata.fields.items():
-            rec[name] = _field.serialize(self[name]) if isinstance(_field, Field) else _field
+            rec[name] = _field.serialize(self[name], record=self) if isinstance(_field, Field) else _field
         return self.__class__(rec, doc_id=self.doc_id)
 
     def deserialize(self, db, recurse: bool = True):
@@ -208,6 +209,10 @@ class Record(typing.Dict[(str, Field)]):
     def reference(self):
         return Pointer.reference(self)
 
+    @property
+    def path(self):
+        return Path(self._metadata.table) / self[self._metadata.primary_key]
+
     def __setattr__(self, key, value):
         if key in self:
             self[key] = value
@@ -238,7 +243,7 @@ class Pointer(Field):
     name: str = ""
     value_type: type = Record
 
-    def serialize(self, value: value_type | str) -> str:
+    def serialize(self, value: value_type | str, record: Record | None = None) -> str:
         return Pointer.reference(value)
 
     def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
@@ -277,6 +282,32 @@ class BackReference(Pointer):
     pass
 
 
+@dataclass
+class FilePointer(Field):
+    """
+    Write the contents of this field to disk and store the path in the db.
+    """
+
+    name: str
+    value_type: type = bytes
+    extension: str = ".txt"
+
+    def relpath(self, record):
+        return Path(record._metadata.table) / record[record._metadata.primary_key] / f"{self.name}{self.extension}"
+
+    def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
+        if not value:
+            return None
+        return (db.path / value).read_bytes()
+
+    def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
+        relpath = self.relpath(record)
+        path = db.path / relpath
+        path.parent.mkdir(parents=True, exist_ok=True)
+        path.write_bytes(record[self.name])
+        record[self.name] = relpath
+
+
 @dataclass
 class Collection(Field):
     """
@@ -286,7 +317,7 @@ class Collection(Field):
     value_type: type = Record
     default: typing.List[value_type] = field(default_factory=lambda: [])
 
-    def serialize(self, values: typing.List[value_type]) -> typing.List[str]:
+    def serialize(self, values: typing.List[value_type], record: Record | None = None) -> typing.List[str]:
         return [Pointer.reference(val) for val in values]
 
     def deserialize(self, values: typing.List[str], db: TinyDB, recurse: bool = False) -> typing.List[value_type]:
@@ -319,7 +350,9 @@ class RecordDict(Field):
     value_type: type = Record
     default: typing.Dict[(str, Record)] = field(default_factory=lambda: {})
 
-    def serialize(self, values: typing.Dict[(str, value_type)]) -> typing.Dict[(str, str)]:
+    def serialize(
+        self, values: typing.Dict[(str, value_type)], record: Record | None = None
+    ) -> typing.Dict[(str, str)]:
         return dict((key, Pointer.reference(val)) for (key, val) in values.items())
 
     def deserialize(
diff --git a/test/test_db.py b/test/test_db.py
index 054ee8b..35c6d40 100644
--- a/test/test_db.py
+++ b/test/test_db.py
@@ -1,4 +1,6 @@
+import tempfile
 from datetime import datetime
+from pathlib import Path
 from pprint import pprint as print
 from time import sleep
 
@@ -13,9 +15,10 @@ from grung.exceptions import PointerReferenceError, UniqueConstraintError
 
 @pytest.fixture
 def db():
-    _db = GrungDB.with_schema(examples, storage=MemoryStorage)
-    yield _db
-    print(_db)
+    with tempfile.TemporaryDirectory() as path:
+        _db = GrungDB.with_schema(examples, path=Path(path), storage=MemoryStorage)
+        yield _db
+        print(_db)
 
 
 def test_crud(db):
@@ -153,10 +156,12 @@ def test_mapping(db):
             name="The Impossible Kid",
             credits={"Produced By": "Aesop Rock", "Lyrics By": "Aesop Rock", "Puke in the MeowMix By": "Kirby"},
             tracks=["Mystery Fish", "Rings", "Lotta Years", "Dorks"],
+            cover=b"some jpg data",
         )
     )
     assert album.credits["Produced By"] == "Aesop Rock"
     assert album.tracks[0] == "Mystery Fish"
+    assert album.cover == b"some jpg data"
 
     aes = db.save(
         examples.Artist(
@@ -170,3 +175,6 @@ def test_mapping(db):
     assert album.name in aes.albums
     assert aes.albums[album.name].uid == album.uid
     assert "Kirby" in aes.albums[album.name].credits.values()
+
+    location_on_disk = db.path / album._metadata.fields["cover"].relpath(album)
+    assert location_on_disk.read_bytes() == album.cover