Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions invokeai/app/services/model_install/model_install_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ def _write_install_marker(self, job: ModelInstallJob, status: Optional[InstallSt
marker = {
"version": INSTALL_MARKER_VERSION,
"source": str(job.source),
"access_token": (
job.source.access_token if isinstance(job.source, (HFModelSource, URLModelSource)) else None
),
"config_in": job.config_in.model_dump(),
"status": (status or job.status).value,
"updated_at": get_iso_timestamp(),
Expand Down Expand Up @@ -200,7 +203,13 @@ def _restore_incomplete_installs(self) -> None:
continue

try:
source_str = marker["source"]
source_str = marker.get("source")
if not isinstance(source_str, str):
raise ValueError("Missing source in install marker")
source = self._guess_source(source_str)
access_token = marker.get("access_token")
if isinstance(source, (HFModelSource, URLModelSource)) and isinstance(access_token, str):
source.access_token = access_token
if source_str in active_sources:
# This tmpdir belongs to an install already in progress; leave it alone.
self._logger.debug(f"Skipping restore for {source_str} - already being tracked")
Expand All @@ -210,7 +219,6 @@ def _restore_incomplete_installs(self) -> None:
self._safe_rmtree(tmpdir, self._logger)
continue
seen_sources.add(source_str)
source = self._guess_source(source_str)
except Exception as e:
self._logger.warning(f"Skipping install marker in {tmpdir}: {e}")
continue
Expand Down
63 changes: 63 additions & 0 deletions tests/app/services/model_install/test_model_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import gc
import platform
import shutil
import uuid
from pathlib import Path
from typing import Any, Dict
Expand All @@ -23,6 +24,7 @@
)
from invokeai.app.services.model_install import (
HFModelSource,
ModelInstallService,
ModelInstallServiceBase,
)
from invokeai.app.services.model_install.model_install_common import (
Expand Down Expand Up @@ -343,6 +345,67 @@ def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_con
assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].parts)


def test_restore_paused_hf_install_preserves_access_token(
mm2_installer: ModelInstallServiceBase,
mm2_app_config: InvokeAIAppConfig,
mm2_download_queue,
mm2_session,
monkeypatch: pytest.MonkeyPatch,
) -> None:
assert isinstance(mm2_installer, ModelInstallService)

access_token = "hf_test_access_token"
tmpdir = mm2_app_config.models_path / f"tmpinstall_resume_token_{uuid.uuid4().hex}"
tmpdir.mkdir(parents=True, exist_ok=True)

try:
paused_job = ModelInstallJob(
id=99999,
source=HFModelSource(
repo_id="stabilityai/sdxl-turbo",
variant=ModelRepoVariant.Default,
access_token=access_token,
),
config_in=ModelRecordChanges(),
local_path=tmpdir,
)
paused_job._install_tmpdir = tmpdir
paused_job.status = InstallStatus.PAUSED

mm2_installer._write_install_marker(paused_job, status=InstallStatus.PAUSED)

marker = mm2_installer._read_install_marker(tmpdir)
assert marker is not None
assert marker["access_token"] == access_token

restored_installer = ModelInstallService(
app_config=mm2_app_config,
record_store=mm2_installer.record_store,
download_queue=mm2_download_queue,
session=mm2_session,
)
restored_installer._restore_incomplete_installs()
restored_jobs = restored_installer.list_jobs()
assert len(restored_jobs) == 1

restored_job = restored_jobs[0]
assert restored_job.paused
assert isinstance(restored_job.source, HFModelSource)
assert restored_job.source.access_token == access_token

captured: dict[str, str | None] = {}

def _capture_resume(job: ModelInstallJob) -> None:
assert isinstance(job.source, HFModelSource)
captured["access_token"] = job.source.access_token

monkeypatch.setattr(restored_installer, "_resume_remote_download", _capture_resume)
restored_installer.resume_job(restored_job)
assert captured["access_token"] == access_token
finally:
shutil.rmtree(tmpdir, ignore_errors=True)


def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
job = mm2_installer.import_model(source)
Expand Down