-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfetch_croissant.py
More file actions
103 lines (82 loc) · 3.31 KB
/
Copy pathfetch_croissant.py
File metadata and controls
103 lines (82 loc) · 3.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Fetch (and optionally commit back) the Croissant JSON-LD for a HF dataset.
HF auto-generates Croissant metadata from a dataset's card YAML + configs. This
script pulls that auto-generated file, writes it locally, and optionally pushes
it back to the dataset repo root as `croissant.json` so it ships alongside the
data.
Usage:
python scripts/release/fetch_croissant.py # save locally
python scripts/release/fetch_croissant.py --commit # also push to HF
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from pathlib import Path
import requests
# Dataset repo on the Hugging Face Hub (org/name); overridable via --repo-id.
DEFAULT_REPO = "causaldrivebench/CausalDriveBench"
# Local destination for the fetched Croissant file; overridable via --out.
# NOTE(review): "/path/to/downloads" looks like a placeholder root — confirm
# it points somewhere writable before relying on this default.
DEFAULT_OUT = (
    Path("/path/to/downloads") / "causaldrivebench_release" / "croissant.json"
)
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, so existing
            ``parse_args()`` calls behave exactly as before.

    Returns:
        Namespace with ``repo_id`` (str), ``out`` (Path), ``commit`` (bool)
        and ``commit_message`` (str).
    """
    # Under ``python -OO`` docstrings are stripped and the module ``__doc__``
    # is None; calling ``.splitlines()`` on it would crash before any option
    # is parsed, so fall back to no description in that case.
    description = __doc__.splitlines()[0] if __doc__ else None
    p = argparse.ArgumentParser(description=description)
    p.add_argument("--repo-id", default=DEFAULT_REPO,
                   help="HF dataset repo id (org/name).")
    p.add_argument("--out", type=Path, default=DEFAULT_OUT,
                   help="Where to write the fetched Croissant JSON.")
    p.add_argument("--commit", action="store_true",
                   help="Upload the fetched Croissant back to the repo root.")
    p.add_argument("--commit-message",
                   default="add auto-generated Croissant metadata",
                   help="HF commit message when --commit is set.")
    return p.parse_args(argv)
def fetch(repo_id: str, token: str | None) -> dict:
    """Download the Hub's auto-generated Croissant JSON-LD for a dataset.

    Args:
        repo_id: Dataset repo id in ``org/name`` form.
        token: Bearer token sent in the ``Authorization`` header; when falsy
            the request is made without that header.

    Returns:
        The decoded JSON body of the response.

    Raises:
        SystemExit: if the Hub answers with any status other than 200. The
            message carries the status, the URL and up to 500 body chars.
    """
    endpoint = f"https://huggingface.co/api/datasets/{repo_id}/croissant"
    auth_headers = {} if not token else {"Authorization": f"Bearer {token}"}
    response = requests.get(endpoint, headers=auth_headers, timeout=30)
    if response.status_code == 200:
        return response.json()
    raise SystemExit(
        f"HF returned {response.status_code} for {endpoint}\n"
        f"Body: {response.text[:500]}\n"
        "Hint: dataset must be pushed first; private datasets need HF_TOKEN."
    )
def commit_back(local_path: Path, repo_id: str, message: str, token: str | None) -> None:
    """Push a local Croissant file to the root of a HF dataset repo.

    The file is stored in the repo as ``croissant.json`` via
    ``HfApi.upload_file`` with ``repo_type="dataset"``.

    Args:
        local_path: File to upload.
        repo_id: Target dataset repo id (``org/name``).
        message: Commit message for the upload.
        token: Token handed to ``HfApi``.

    Raises:
        SystemExit: if ``huggingface_hub`` cannot be imported.
    """
    try:
        from huggingface_hub import HfApi
    except ImportError:
        raise SystemExit("Run: pip install huggingface_hub")
    hub = HfApi(token=token)
    source = str(local_path)
    hub.upload_file(
        path_or_fileobj=source,
        path_in_repo="croissant.json",
        repo_id=repo_id,
        repo_type="dataset",
        commit_message=message,
    )
    print(f"[commit] pushed croissant.json -> hf://datasets/{repo_id}")
def resolve_token() -> str | None:
"""HF_TOKEN env var, falling back to the cached `hf auth login` token."""
tok = os.environ.get("HF_TOKEN")
if tok:
return tok
try:
from huggingface_hub import get_token
return get_token()
except ImportError:
return None
def main() -> None:
    """Fetch the Croissant file, write it to ``--out`` and optionally push it.

    A token is mandatory only for ``--commit``, because uploading needs write
    access. ``fetch`` already accepts ``token=None`` and then omits the
    Authorization header, and its own error hint notes that only private
    datasets need HF_TOKEN. Without a token the request is therefore made
    anonymously instead of aborting.
    """
    args = parse_args()
    token = resolve_token()
    if not token:
        if args.commit:
            # Fail before any network I/O: the upload step cannot succeed.
            raise SystemExit(
                "No HF token found (required for --commit). "
                "Set HF_TOKEN or run `hf auth login`."
            )
        print("[fetch] no HF token found; requesting anonymously",
              file=sys.stderr)
    crs = fetch(args.repo_id, token)
    args.out.parent.mkdir(parents=True, exist_ok=True)
    args.out.write_text(json.dumps(crs, indent=2))
    print(f"[fetch] wrote {args.out} ({len(crs.get('recordSet', []))} record sets, "
          f"{len(crs.get('distribution', []))} distributions)")
    if args.commit:
        commit_back(args.out, args.repo_id, args.commit_message, token)


if __name__ == "__main__":
    main()