Skip to content

Commit df8d73d

Browse files
Merge branch 'master' into fix/1429-cascade-part-part-renamed-fk
2 parents 9094a64 + 097d8c4 commit df8d73d

2 files changed

Lines changed: 189 additions & 127 deletions

File tree

src/datajoint/staged_insert.py

Lines changed: 91 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,19 @@
55
to object storage before finalizing the database insert.
66
"""
77

8-
import json
9-
import mimetypes
108
from contextlib import contextmanager
119
from datetime import datetime, timezone
12-
from typing import IO, Any
10+
from typing import IO, TYPE_CHECKING, Any
1311

1412
import fsspec
1513

14+
from .codecs import resolve_dtype
1615
from .errors import DataJointError
17-
from .storage import StorageBackend, build_object_path
16+
from .hash_registry import get_store_backend
17+
from .storage import build_object_path
18+
19+
if TYPE_CHECKING:
20+
from .storage import StorageBackend
1821

1922

2023
class StagedInsert:
@@ -30,15 +33,14 @@ class StagedInsert:
3033
staged.rec['subject_id'] = 123
3134
staged.rec['session_id'] = 45
3235
33-
# Create object storage directly
36+
# Write directly to object storage
3437
z = zarr.open(staged.store('raw_data', '.zarr'), mode='w', shape=(1000, 1000))
3538
z[:] = data
3639
37-
# Assign to record
38-
staged.rec['raw_data'] = z
39-
40-
# On successful exit: metadata computed, record inserted
41-
# On exception: storage cleaned up, no record inserted
40+
# On clean exit: metadata is computed and the row is inserted.
41+
# The caller does NOT assign anything to staged.rec[<object field>] —
42+
# the framework computes the metadata dict.
43+
# On exception: storage cleaned up, no row inserted.
4244
"""
4345

4446
def __init__(self, table):
@@ -50,8 +52,7 @@ def __init__(self, table):
5052
"""
5153
self._table = table
5254
self._rec: dict[str, Any] = {}
53-
self._staged_objects: dict[str, dict] = {} # field -> {path, ext, token}
54-
self._backend: StorageBackend | None = None
55+
self._staged_objects: dict[str, dict] = {} # field -> {relative_path, ext, token, store_name}
5556

5657
@property
5758
def rec(self) -> dict[str, Any]:
@@ -60,60 +61,57 @@ def rec(self) -> dict[str, Any]:
6061

6162
@property
6263
def fs(self) -> fsspec.AbstractFileSystem:
63-
"""Return fsspec filesystem for advanced operations."""
64-
self._ensure_backend()
65-
return self._backend.fs
64+
"""
65+
Return fsspec filesystem for the default store, for advanced operations.
6666
67-
def _ensure_backend(self):
68-
"""Ensure storage backend is initialized."""
69-
if self._backend is None:
70-
try:
71-
spec = self._table.connection._config.get_store_spec() # Uses stores.default
72-
self._backend = StorageBackend(spec)
73-
except DataJointError:
74-
raise DataJointError(
75-
"Storage is not configured. Set stores.default and stores.<name> settings in datajoint.json."
76-
)
77-
78-
def _get_storage_path(self, field: str, ext: str = "") -> str:
67+
For per-field access, prefer ``staged.store(field)`` or ``staged.open(field)`` —
68+
those route to the store resolved from the field's type spec.
7969
"""
80-
Get or create the storage path for a field.
70+
return self._default_backend().fs
8171

82-
Args:
83-
field: Name of the object attribute
84-
ext: Optional extension (e.g., ".zarr")
72+
def _default_backend(self):
73+
"""Return the StorageBackend for the default store, or raise a clear error."""
74+
try:
75+
return get_store_backend(None, config=self._table.connection._config)
76+
except DataJointError:
77+
raise DataJointError("Storage is not configured. Set stores.default and stores.<name> settings in datajoint.json.")
8578

86-
Returns:
87-
Full storage path
79+
def _resolve_field(self, field: str, ext: str) -> tuple[str, "StorageBackend"]:
8880
"""
89-
self._ensure_backend()
81+
Resolve a field to its (relative_path, backend), caching on first call.
9082
83+
Validates the field is an ``<object@>`` attribute and that the full
84+
primary key is set on ``staged.rec``.
85+
"""
9186
if field in self._staged_objects:
92-
return self._staged_objects[field]["full_path"]
87+
info = self._staged_objects[field]
88+
return info["relative_path"], self._field_backend(info["store_name"])
9389

94-
# Validate field is an object attribute
9590
if field not in self._table.heading:
9691
raise DataJointError(f"Attribute '{field}' not found in table heading")
9792

9893
attr = self._table.heading[field]
99-
# Check if this is an object Codec (has codec with "object" as name)
10094
if not (attr.codec and attr.codec.name == "object"):
10195
raise DataJointError(f"Attribute '{field}' is not an <object> type")
10296

103-
# Extract primary key from rec
10497
primary_key = {k: self._rec[k] for k in self._table.primary_key if k in self._rec}
10598
if len(primary_key) != len(self._table.primary_key):
10699
raise DataJointError(
107100
"Primary key values must be set in staged.rec before calling store() or open(). "
108101
f"Missing: {set(self._table.primary_key) - set(primary_key)}"
109102
)
110103

111-
# Get storage spec (uses stores.default)
112-
spec = self._table.connection._config.get_store_spec()
104+
# Resolve the store name from the field's type spec (e.g., <object@local> -> "local")
105+
_, _, store_name = resolve_dtype(f"<{attr.codec.name}>", store_name=attr.store)
106+
107+
config = self._table.connection._config
108+
try:
109+
spec = config.get_store_spec(store_name)
110+
except DataJointError:
111+
raise DataJointError("Storage is not configured. Set stores.default and stores.<name> settings in datajoint.json.")
113112
partition_pattern = spec.get("partition_pattern")
114113
token_length = spec.get("token_length", 8)
115114

116-
# Build storage path (relative - StorageBackend will add location prefix)
117115
relative_path, token = build_object_path(
118116
schema=self._table.database,
119117
table=self._table.class_name,
@@ -124,18 +122,25 @@ def _get_storage_path(self, field: str, ext: str = "") -> str:
124122
token_length=token_length,
125123
)
126124

127-
# Store staged object info (all paths are relative, backend adds location)
128125
self._staged_objects[field] = {
129126
"relative_path": relative_path,
130127
"ext": ext if ext else None,
131128
"token": token,
129+
"store_name": store_name,
132130
}
133131

134-
return relative_path
132+
return relative_path, self._field_backend(store_name)
133+
134+
def _field_backend(self, store_name: str | None):
135+
"""Return the StorageBackend for the named store."""
136+
try:
137+
return get_store_backend(store_name, config=self._table.connection._config)
138+
except DataJointError:
139+
raise DataJointError("Storage is not configured. Set stores.default and stores.<name> settings in datajoint.json.")
135140

136141
def store(self, field: str, ext: str = "") -> fsspec.FSMap:
137142
"""
138-
Get an FSMap store for direct writes to an object field.
143+
Get an FSMap for direct writes to an ``<object@>`` field.
139144
140145
Args:
141146
field: Name of the object attribute
@@ -144,12 +149,12 @@ def store(self, field: str, ext: str = "") -> fsspec.FSMap:
144149
Returns:
145150
fsspec.FSMap suitable for Zarr/xarray
146151
"""
147-
path = self._get_storage_path(field, ext)
148-
return self._backend.get_fsmap(path)
152+
relative_path, backend = self._resolve_field(field, ext)
153+
return backend.get_fsmap(relative_path)
149154

150155
def open(self, field: str, ext: str = "", mode: str = "wb") -> IO:
151156
"""
152-
Open a file for direct writes to an object field.
157+
Open a file for direct writes to an ``<object@>`` field.
153158
154159
Args:
155160
field: Name of the object attribute
@@ -159,127 +164,86 @@ def open(self, field: str, ext: str = "", mode: str = "wb") -> IO:
159164
Returns:
160165
File-like object for writing
161166
"""
162-
path = self._get_storage_path(field, ext)
163-
return self._backend.open(path, mode)
167+
relative_path, backend = self._resolve_field(field, ext)
168+
return backend.open(relative_path, mode)
164169

165170
def _compute_metadata(self, field: str) -> dict:
166171
"""
167-
Compute metadata for a staged object after writing is complete.
172+
Compute the canonical ``<object@>`` metadata dict for a staged write.
168173
169-
Args:
170-
field: Name of the object attribute
174+
The returned dict is structurally equal to what ``ObjectCodec.encode``
175+
would produce for the same content, modulo ``timestamp``.
171176
172-
Returns:
173-
JSON-serializable metadata dict
177+
Returns
178+
-------
179+
dict
180+
``{path, store, size, ext, is_dir, item_count, timestamp}``
174181
"""
175182
info = self._staged_objects[field]
176183
relative_path = info["relative_path"]
177184
ext = info["ext"]
185+
store_name = info["store_name"]
186+
backend = self._field_backend(store_name)
178187

179-
# Check if it's a directory (multiple files) or single file
180-
# _full_path adds the location prefix
181-
full_remote_path = self._backend._full_path(relative_path)
188+
full_remote_path = backend._full_path(relative_path)
182189

183190
try:
184-
is_dir = self._backend.fs.isdir(full_remote_path)
191+
is_dir = backend.fs.isdir(full_remote_path)
185192
except Exception:
186193
is_dir = False
187194

188195
if is_dir:
189-
# Calculate total size and file count
190196
total_size = 0
191197
item_count = 0
192-
files = []
193-
194-
for root, dirs, filenames in self._backend.fs.walk(full_remote_path):
198+
for root, _dirs, filenames in backend.fs.walk(full_remote_path):
195199
for filename in filenames:
196-
file_path = f"{root}/{filename}"
197200
try:
198-
file_size = self._backend.fs.size(file_path)
199-
rel_path = file_path[len(full_remote_path) :].lstrip("/")
200-
files.append({"path": rel_path, "size": file_size})
201-
total_size += file_size
201+
total_size += backend.fs.size(f"{root}/{filename}")
202202
item_count += 1
203203
except Exception:
204204
pass
205-
206-
# Create manifest
207-
manifest = {
208-
"files": files,
209-
"total_size": total_size,
210-
"item_count": item_count,
211-
"created": datetime.now(timezone.utc).isoformat(),
212-
}
213-
214-
# Write manifest alongside folder
215-
manifest_path = f"{relative_path}.manifest.json"
216-
self._backend.put_buffer(json.dumps(manifest, indent=2).encode(), manifest_path)
217-
218-
metadata = {
219-
"path": relative_path,
220-
"size": total_size,
221-
"hash": None,
222-
"ext": ext,
223-
"is_dir": True,
224-
"timestamp": datetime.now(timezone.utc).isoformat(),
225-
"item_count": item_count,
226-
}
205+
size = total_size
227206
else:
228-
# Single file
229207
try:
230-
size = self._backend.size(relative_path)
208+
size = backend.size(relative_path)
231209
except Exception:
232210
size = 0
233-
234-
metadata = {
235-
"path": relative_path,
236-
"size": size,
237-
"hash": None,
238-
"ext": ext,
239-
"is_dir": False,
240-
"timestamp": datetime.now(timezone.utc).isoformat(),
241-
}
242-
243-
# Add mime_type for files
244-
if ext:
245-
mime_type, _ = mimetypes.guess_type(f"file{ext}")
246-
if mime_type:
247-
metadata["mime_type"] = mime_type
248-
249-
return metadata
211+
item_count = None
212+
213+
return {
214+
"path": relative_path,
215+
"store": store_name,
216+
"size": size,
217+
"ext": ext,
218+
"is_dir": is_dir,
219+
"item_count": item_count,
220+
"timestamp": datetime.now(timezone.utc).isoformat(),
221+
}
250222

251223
def _finalize(self):
252224
"""
253-
Finalize the staged insert by computing metadata and inserting the record.
225+
Compute metadata for each staged object and insert the row.
254226
"""
255-
# Process each staged object
256227
for field in list(self._staged_objects.keys()):
257-
metadata = self._compute_metadata(field)
258-
# Store metadata dict in the record (ObjectType.encode handles it)
259-
self._rec[field] = metadata
260-
261-
# Insert the record
228+
self._rec[field] = self._compute_metadata(field)
262229
self._table.insert1(self._rec)
263230

264231
def _cleanup(self):
265232
"""
266-
Clean up staged objects on failure.
233+
Best-effort removal of staged objects on failure.
267234
"""
268-
if self._backend is None:
269-
return
270-
271235
for field, info in self._staged_objects.items():
272236
relative_path = info["relative_path"]
273237
try:
274-
# Check if it's a directory
275-
full_remote_path = self._backend._full_path(relative_path)
276-
if self._backend.fs.exists(full_remote_path):
277-
if self._backend.fs.isdir(full_remote_path):
278-
self._backend.remove_folder(relative_path)
238+
backend = self._field_backend(info["store_name"])
239+
full_remote_path = backend._full_path(relative_path)
240+
if backend.fs.exists(full_remote_path):
241+
if backend.fs.isdir(full_remote_path):
242+
backend.remove_folder(relative_path)
279243
else:
280-
self._backend.remove(relative_path)
244+
backend.remove(relative_path)
281245
except Exception:
282-
pass # Best effort cleanup
246+
pass # Best-effort cleanup
283247

284248

285249
@contextmanager
@@ -299,7 +263,7 @@ def staged_insert1(table):
299263
staged.rec['session_id'] = 45
300264
z = zarr.open(staged.store('raw_data', '.zarr'), mode='w')
301265
z[:] = data
302-
staged.rec['raw_data'] = z
266+
# Metadata for 'raw_data' is computed on clean exit; do not assign it here.
303267
"""
304268
staged = StagedInsert(table)
305269
try:

0 commit comments

Comments
 (0)