Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog.d/gcs-runtime-datasets.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Restore runtime loading for `gs://` dataset URIs and materialize remote
datasets before handing them to country package microsimulations.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ dependencies = [
"requests>=2.31.0",
"psutil>=5.9.0",
"packaging>=23.0",
"google-cloud-storage>=3.1.0,<4.0.0",
"diskcache>=5.6.3,<6.0.0",
]

[project.scripts]
Expand Down
105 changes: 105 additions & 0 deletions src/policyengine/provenance/dataset_sources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Runtime dataset source materialization."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

from policyengine.utils.google_cloud_bucket import download_file_from_gcs


@dataclass(frozen=True)
class GCSArtifactReference:
    """Immutable description of an object addressed by a ``gs://`` URI.

    Instances are produced by ``parse_gs_uri`` from URIs of the form
    ``gs://bucket/path/to/file[@version]``.
    """

    # URI segment between "gs://" and the first "/".
    bucket: str
    # Everything after the first "/" (with any "@version" suffix removed).
    path: str
    # Text after the last "@" in the URI; None when the URI contains no "@".
    version: Optional[str] = None


@dataclass(frozen=True)
class HFArtifactReference:
    """Immutable description of a file addressed by an ``hf://`` URI.

    Instances are produced by ``parse_hf_uri`` from URIs of the form
    ``hf://owner/repo/path/to/file[@revision]``.
    """

    # "owner/repo", assembled from the first two path segments of the URI.
    repo_id: str
    # Remaining path inside the repository (third segment onwards).
    path: str
    # Text after the last "@" in the URI; None when the URI contains no "@".
    version: Optional[str] = None


def _select_version(
    uri_version: Optional[str],
    requested_version: Optional[str],
) -> Optional[str]:
    """Reconcile a version embedded in a URI with an explicitly requested one.

    An empty string is treated the same as ``None`` ("not specified") on both
    sides. A URI that ends in a bare ``@`` parses to an empty version, and it
    should not be reported as conflicting with a real requested version.

    Args:
        uri_version: Version taken from the ``@version`` suffix of the URI,
            if any.
        requested_version: Version passed explicitly by the caller, if any.

    Returns:
        The single version to use, or ``None`` when neither side names one.

    Raises:
        ValueError: If both sides name a version and the two differ.
    """
    # Normalise "" to None up front so the conflict check and the fallback
    # agree on what "unspecified" means.
    uri_version = uri_version or None
    requested_version = requested_version or None

    if uri_version is None:
        return requested_version
    if requested_version is not None and requested_version != uri_version:
        raise ValueError(
            "Conflicting dataset versions: "
            f"URI requests {uri_version!r} but version is {requested_version!r}"
        )
    return uri_version


def parse_gs_uri(uri: str) -> GCSArtifactReference:
    """Split a ``gs://bucket/path[@version]`` URI into its components.

    The text after the last ``@`` (when one is present) becomes the version;
    the remainder is divided at the first ``/`` into bucket and object path.

    Raises:
        ValueError: If the URI does not start with ``gs://`` or if the bucket
            or the object path is missing.
    """
    scheme = "gs://"
    if not uri.startswith(scheme):
        raise ValueError(f"Invalid GCS dataset URI: {uri!r}")

    location = uri[len(scheme):]
    version = None
    if "@" in location:
        # Split on the right-most "@" only; earlier ones stay in the path.
        location, version = location.rsplit("@", maxsplit=1)

    bucket, slash, path = location.partition("/")
    if not (bucket and slash and path):
        raise ValueError(
            "Invalid GCS dataset URI. Expected format "
            f"'gs://bucket/path/to/file[@version]', got {uri!r}."
        )
    return GCSArtifactReference(bucket=bucket, path=path, version=version)


def parse_hf_uri(uri: str) -> HFArtifactReference:
    """Split an ``hf://owner/repo/path[@revision]`` URI into its components.

    The text after the last ``@`` (when one is present) becomes the version.
    The first two ``/``-separated segments form the repository id and the
    rest is the file path inside the repository.

    Raises:
        ValueError: If the URI does not start with ``hf://`` or if the owner,
            repository or file path is missing or empty.
    """
    scheme = "hf://"
    if not uri.startswith(scheme):
        raise ValueError(f"Invalid Hugging Face dataset URI: {uri!r}")

    location = uri[len(scheme):]
    revision = None
    if "@" in location:
        # Split on the right-most "@" only.
        location, revision = location.rsplit("@", maxsplit=1)

    segments = location.split("/", maxsplit=2)
    well_formed = len(segments) == 3 and all(segments)
    if not well_formed:
        raise ValueError(
            "Invalid Hugging Face dataset URI. Expected format "
            f"'hf://owner/repo/path/to/file[@revision]', got {uri!r}."
        )

    owner, repo, file_path = segments
    return HFArtifactReference(
        repo_id=f"{owner}/{repo}",
        path=file_path,
        version=revision,
    )


def materialize_dataset_source(
    dataset_source: str,
    *,
    version: Optional[str] = None,
) -> str:
    """Resolve a dataset source to something loadable from local disk.

    ``gs://`` sources are fetched with ``download_file_from_gcs`` and
    ``hf://`` sources with ``download_huggingface_dataset``; in both cases
    the downloader's result is returned. Any other value is returned
    unchanged.

    Args:
        dataset_source: A ``gs://`` URI, an ``hf://`` URI, or any other
            string, which is passed through untouched.
        version: Optional explicit version; it must agree with a version
            embedded in the URI when both are given.

    Raises:
        ValueError: If the URI is malformed or the two versions conflict.
    """
    if dataset_source.startswith("gs://"):
        gcs_ref = parse_gs_uri(dataset_source)
        gcs_version = _select_version(gcs_ref.version, version)
        # The downloader returns a pair; only the first element is used here.
        downloaded_path, _ = download_file_from_gcs(
            gcs_ref.bucket,
            gcs_ref.path,
            version=gcs_version,
        )
        return downloaded_path

    if dataset_source.startswith("hf://"):
        # Imported lazily so policyengine_core is only needed for hf:// URIs.
        from policyengine_core.tools.hugging_face import (
            download_huggingface_dataset,
        )

        hf_ref = parse_hf_uri(dataset_source)
        hf_version = _select_version(hf_ref.version, version)
        return download_huggingface_dataset(
            hf_ref.repo_id,
            hf_ref.path,
            version=hf_version,
        )

    # Not a recognised remote scheme: hand the value back as-is.
    return dataset_source
4 changes: 3 additions & 1 deletion src/policyengine/tax_benefit_models/uk/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pydantic import ConfigDict

from policyengine.core import Dataset, YearData
from policyengine.provenance.dataset_sources import materialize_dataset_source
from policyengine.provenance.manifest import (
dataset_logical_name,
resolve_dataset_reference,
Expand Down Expand Up @@ -111,9 +112,10 @@ def create_datasets(
for dataset in datasets:
resolved_dataset = resolve_dataset_reference("uk", dataset)
dataset_stem = dataset_logical_name(resolved_dataset)
runtime_dataset = materialize_dataset_source(resolved_dataset)
from policyengine_uk import Microsimulation

sim = Microsimulation(dataset=resolved_dataset)
sim = Microsimulation(dataset=runtime_dataset)
for year in years:
year_dataset = sim.dataset[year]

Expand Down
16 changes: 9 additions & 7 deletions src/policyengine/tax_benefit_models/uk/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from microdf import MicroDataFrame

from policyengine.core import TaxBenefitModel
from policyengine.provenance.dataset_sources import materialize_dataset_source
from policyengine.provenance.manifest import (
dataset_logical_name,
resolve_local_managed_dataset_source,
Expand Down Expand Up @@ -288,21 +289,22 @@ def managed_microsimulation(
allow_unmanaged and dataset is not None and "://" in dataset
),
)
runtime_dataset = dataset_source
if isinstance(dataset_source, str) and "hf://" not in dataset_source:
runtime_dataset_source = materialize_dataset_source(dataset_source)
runtime_dataset = runtime_dataset_source
if isinstance(runtime_dataset_source, str) and "://" not in runtime_dataset_source:
from policyengine_uk.data.dataset_schema import (
UKMultiYearDataset,
UKSingleYearDataset,
)

if UKMultiYearDataset.validate_file_path(dataset_source, False):
runtime_dataset = UKMultiYearDataset(dataset_source)
elif UKSingleYearDataset.validate_file_path(dataset_source, False):
runtime_dataset = UKSingleYearDataset(dataset_source)
if UKMultiYearDataset.validate_file_path(runtime_dataset_source, False):
runtime_dataset = UKMultiYearDataset(runtime_dataset_source)
elif UKSingleYearDataset.validate_file_path(runtime_dataset_source, False):
runtime_dataset = UKSingleYearDataset(runtime_dataset_source)
microsim = Microsimulation(dataset=runtime_dataset, **kwargs)
microsim.policyengine_bundle = _managed_release_bundle(
dataset_uri,
dataset_source,
runtime_dataset_source,
)
return microsim

Expand Down
4 changes: 3 additions & 1 deletion src/policyengine/tax_benefit_models/us/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pydantic import ConfigDict, Field

from policyengine.core import Dataset, YearData
from policyengine.provenance.dataset_sources import materialize_dataset_source
from policyengine.provenance.manifest import (
dataset_logical_name,
get_release_manifest,
Expand Down Expand Up @@ -285,7 +286,8 @@ def create_datasets(
for dataset in datasets:
resolved_dataset = resolve_dataset_reference("us", dataset)
dataset_stem = dataset_logical_name(resolved_dataset)
sim = Microsimulation(dataset=resolved_dataset)
runtime_dataset = materialize_dataset_source(resolved_dataset)
sim = Microsimulation(dataset=runtime_dataset)

for year in years:
# Get all input variables from the simulation
Expand Down
6 changes: 4 additions & 2 deletions src/policyengine/tax_benefit_models/us/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from microdf import MicroDataFrame

from policyengine.core import TaxBenefitModel
from policyengine.provenance.dataset_sources import materialize_dataset_source
from policyengine.provenance.manifest import (
dataset_logical_name,
resolve_local_managed_dataset_source,
Expand Down Expand Up @@ -438,10 +439,11 @@ def managed_microsimulation(
allow_unmanaged and dataset is not None and "://" in dataset
),
)
microsim = Microsimulation(dataset=dataset_source, **kwargs)
runtime_dataset_source = materialize_dataset_source(dataset_source)
microsim = Microsimulation(dataset=runtime_dataset_source, **kwargs)
microsim.policyengine_bundle = _managed_release_bundle(
dataset_uri,
dataset_source,
runtime_dataset_source,
)
return microsim

Expand Down
1 change: 1 addition & 0 deletions src/policyengine/utils/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Dataset download helpers."""
105 changes: 105 additions & 0 deletions src/policyengine/utils/data/caching_google_storage_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Disk-cached Google Cloud Storage downloads."""

from __future__ import annotations

import logging
import os
import tempfile
from contextlib import AbstractContextManager
from pathlib import Path
from typing import Optional

import diskcache

from .version_aware_storage_client import VersionAwareStorageClient

logger = logging.getLogger(__name__)


def _atomic_write(target: Path, content: bytes) -> None:
    """Write ``content`` to ``target`` via a temporary file and a rename.

    The bytes are first written to a temporary file created in the target's
    own directory, which is then moved over ``target`` with ``os.replace``.
    Missing parent directories are created. If anything fails before the
    move, the temporary file is removed.
    """
    directory = target.parent
    directory.mkdir(parents=True, exist_ok=True)

    staged: Optional[Path] = None
    try:
        handle = tempfile.NamedTemporaryFile(dir=directory, delete=False)
        with handle:
            staged = Path(handle.name)
            handle.write(content)
        # The handle is closed at this point, so the move sees the full file.
        os.replace(staged, target)
    finally:
        # After a successful replace the staged file no longer exists.
        if staged is not None and staged.exists():
            staged.unlink()


class CachingGoogleStorageClient(AbstractContextManager):
    """Download GCS objects through a CRC-keyed disk cache.

    Each object is cached under two keys derived from bucket, object key and
    version: one holding the raw bytes and one holding the CRC32C reported
    at download time. An object is re-downloaded when the remote CRC differs
    from the cached one or when the cached bytes are no longer present.

    Used as a context manager, the cache is cleared and closed on exit.
    """

    def __init__(self) -> None:
        self.client = VersionAwareStorageClient()
        self.cache = diskcache.Cache()

    @staticmethod
    def _data_key(bucket: str, key: str, version: Optional[str] = None) -> str:
        """Return the cache key under which the object's bytes are stored."""
        return f"{bucket}.{key}.{version}.data"

    @staticmethod
    def _crc_key(bucket: str, key: str, version: Optional[str] = None) -> str:
        """Return the cache key under which the object's CRC is stored."""
        return f"{bucket}.{key}.{version}.crc"

    def download(
        self,
        bucket: str,
        key: str,
        target: Path,
        version: Optional[str] = None,
        return_version: bool = False,
    ) -> Optional[str]:
        """Write the object at ``bucket``/``key`` to ``target``.

        Args:
            bucket: Bucket name.
            key: Object key inside the bucket.
            target: Local destination path; it is written atomically.
            version: Version to fetch. When ``None``, the client's latest
                metadata version is looked up and a warning is logged.
            return_version: When true, return the version that was used.

        Returns:
            The resolved version if ``return_version`` is true, else ``None``.

        Raises:
            FileNotFoundError: If the object has no CRC (see ``sync``).
            TypeError: If the cached payload is not ``bytes``.
        """
        if version is None:
            version = self.client.latest_metadata_version(bucket, key)
            logger.warning(
                "No version specified for %s/%s; using latest metadata version %s",
                bucket,
                key,
                version,
            )

        self.sync(bucket, key, version)
        data = self.cache.get(self._data_key(bucket, key, version))
        if isinstance(data, bytes):
            _atomic_write(target, data)
            return version if return_version else None

        raise TypeError(
            f"Expected cached data for {bucket}/{key}@{version} to be bytes"
        )

    def sync(
        self,
        bucket: str,
        key: str,
        version: Optional[str] = None,
    ) -> None:
        """Ensure the cache holds the current bytes for the given object.

        Raises:
            FileNotFoundError: If the client reports no CRC for the object.
        """
        crc = self.client.crc32c(bucket, key, version=version)
        if crc is None:
            raise FileNotFoundError(f"Unable to find gs://{bucket}/{key}")

        data_key = self._data_key(bucket, key, version)
        crc_key = self._crc_key(bucket, key, version)
        cached_crc = self.cache.get(crc_key, default=None)
        # A matching CRC alone is not a cache hit: the data entry can be
        # missing (e.g. evicted) while the much smaller CRC entry survives.
        if cached_crc == crc and data_key in self.cache:
            return

        content, downloaded_crc = self.client.download(bucket, key, version=version)
        with self.cache as cache:
            # Data is stored before its CRC so that an interrupted write
            # leaves a CRC mismatch and triggers a re-download next time.
            cache.set(data_key, content)
            cache.set(crc_key, downloaded_crc)

    def clear(self) -> None:
        """Remove every entry from the cache."""
        self.cache.clear()

    def __enter__(self) -> CachingGoogleStorageClient:
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Clear the cache and release its handle; never suppress errors."""
        try:
            self.clear()
        finally:
            # Release the cache's underlying handle even if clearing failed.
            self.cache.close()
        return None
Loading
Loading