From 5eb3af3c1d3b88dc4a11587bbe71469e2ee327f5 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Sun, 24 May 2026 01:15:32 +0530 Subject: [PATCH 01/25] feat(migration): SDK subpackage for org-to-org data migration (v1, adapter) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds `unstract.migration` — a CLI + library that drives org-to-org data migration between Unstract deployments via the existing Platform API surface. v1 ships the adapter phase end-to-end as the reference impl; remaining phases (connector, tag, custom_tool, workflow, tool_instance, workflow_endpoint) follow in subsequent commits. Design highlights: - Two admin Platform API keys in, populated target out. No bundle, no Django mgmt command, no out-of-band secret step. - Idempotency via in-memory remap (per run) + name-based GET against target (across runs). No state file — target IS the state. - Fresh target UUIDs on every create; embedded UUID references remapped in-memory via `walker.remap_uuids` pre-POST. - Secrets carried verbatim from source GET to target POST (same surface the FE already consumes when an admin opens an adapter card). 
Package layout: - `client.py` thin Platform API wrapper (one per OrgEndpoint) - `context.py` OrgEndpoint, MigrationOptions, MigrationContext, RemapTable - `walker.py` `remap_uuids` JSON walker - `report.py` rich-rendered MigrationReport + plain-text fallback - `phases/` base.Phase ABC + adapter.AdapterPhase (reference impl) - `orchestrator.py` top-level `migrate()` + phase order - `cli.py` click CLI: `unstract-migrate migrate ...` CLI: - Entry point `unstract-migrate` (also `python -m unstract.migration`) - Platform keys via flags or env (UNSTRACT_SRC_PLATFORM_KEY / UNSTRACT_TGT_PLATFORM_KEY) to keep them out of shell history - `--api-prefix` overrides PATH_PREFIX (default api/v1 to match OSS docker compose; cloud/on-prem set as needed) - `--dry-run`, `--include`, `--exclude`, `--on-name-conflict adopt|abort` Audit: - Per-line logs include source UUID + target UUID for every adopted/created entity (`src=... -> tgt=...`) - Final report renders a source -> target UUID map table for traceability Deps gated behind optional `[migration]` extra — core SDK consumers unaffected. Tests: 13 unit tests (RemapTable, remap_uuids, AdapterPhase happy path / idempotency / dry-run / abort). Integration smoke verified locally against docker compose: 9 adapters migrated source -> target, re-run reports 0 created / 9 adopted. 
--- pyproject.toml | 9 ++ src/unstract/migration/__init__.py | 25 ++++ src/unstract/migration/__main__.py | 6 + src/unstract/migration/cli.py | 138 +++++++++++++++++++ src/unstract/migration/client.py | 101 ++++++++++++++ src/unstract/migration/context.py | 98 ++++++++++++++ src/unstract/migration/exceptions.py | 22 +++ src/unstract/migration/orchestrator.py | 66 +++++++++ src/unstract/migration/phases/__init__.py | 13 ++ src/unstract/migration/phases/adapter.py | 106 +++++++++++++++ src/unstract/migration/phases/base.py | 25 ++++ src/unstract/migration/report.py | 111 ++++++++++++++++ src/unstract/migration/walker.py | 32 +++++ tests/migration/__init__.py | 0 tests/migration/test_adapter_phase.py | 155 ++++++++++++++++++++++ tests/migration/test_remap_table.py | 36 +++++ tests/migration/test_walker.py | 54 ++++++++ uv.lock | 57 +++++++- 18 files changed, 1053 insertions(+), 1 deletion(-) create mode 100644 src/unstract/migration/__init__.py create mode 100644 src/unstract/migration/__main__.py create mode 100644 src/unstract/migration/cli.py create mode 100644 src/unstract/migration/client.py create mode 100644 src/unstract/migration/context.py create mode 100644 src/unstract/migration/exceptions.py create mode 100644 src/unstract/migration/orchestrator.py create mode 100644 src/unstract/migration/phases/__init__.py create mode 100644 src/unstract/migration/phases/adapter.py create mode 100644 src/unstract/migration/phases/base.py create mode 100644 src/unstract/migration/report.py create mode 100644 src/unstract/migration/walker.py create mode 100644 tests/migration/__init__.py create mode 100644 tests/migration/test_adapter_phase.py create mode 100644 tests/migration/test_remap_table.py create mode 100644 tests/migration/test_walker.py diff --git a/pyproject.toml b/pyproject.toml index c45a8a9..0050c98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,15 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Python Modules", ] 
+[project.optional-dependencies] +migration = [ + "click>=8.1", + "rich>=13.7", +] + +[project.scripts] +unstract-migrate = "unstract.migration.cli:main" + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" diff --git a/src/unstract/migration/__init__.py b/src/unstract/migration/__init__.py new file mode 100644 index 0000000..728b5b6 --- /dev/null +++ b/src/unstract/migration/__init__.py @@ -0,0 +1,25 @@ +"""Org-to-org data migration over the Platform API. + +Migrates configured resources (adapters, connectors, custom tools, workflows, +etc.) from one Unstract org to another using two admin-issued Platform API +keys. The target deployment is the persistent state — re-runs reconcile +against existing target rows by natural key. +""" + +from unstract.migration.context import ( + MigrationContext, + MigrationOptions, + OrgEndpoint, + RemapTable, +) +from unstract.migration.orchestrator import migrate +from unstract.migration.report import MigrationReport + +__all__ = [ + "MigrationContext", + "MigrationOptions", + "MigrationReport", + "OrgEndpoint", + "RemapTable", + "migrate", +] diff --git a/src/unstract/migration/__main__.py b/src/unstract/migration/__main__.py new file mode 100644 index 0000000..c9d3fd2 --- /dev/null +++ b/src/unstract/migration/__main__.py @@ -0,0 +1,6 @@ +"""Entry point: ``python -m unstract.migration``.""" + +from unstract.migration.cli import main + +if __name__ == "__main__": + main() diff --git a/src/unstract/migration/cli.py b/src/unstract/migration/cli.py new file mode 100644 index 0000000..c03311e --- /dev/null +++ b/src/unstract/migration/cli.py @@ -0,0 +1,138 @@ +"""Click-based CLI for ``unstract.migration``. + +Single ``migrate`` command. Platform keys can be passed via flags +(``--source-key`` / ``--target-key``) or env vars +(``UNSTRACT_SRC_PLATFORM_KEY`` / ``UNSTRACT_TGT_PLATFORM_KEY``) — env vars +are preferred so the key never lands in shell history. 
+""" + +from __future__ import annotations + +import logging +import sys +from typing import Any + +import click + +from unstract.migration.context import MigrationOptions, OrgEndpoint +from unstract.migration.exceptions import MigrationError +from unstract.migration.orchestrator import migrate as run_migrate + + +def _configure_logging(verbose: bool) -> None: + level = logging.DEBUG if verbose else logging.INFO + logging.basicConfig( + level=level, + format="%(asctime)s %(levelname)-7s %(name)s: %(message)s", + datefmt="%H:%M:%S", + ) + + +def _split_csv(value: str | None) -> tuple[str, ...] | None: + if not value: + return None + return tuple(p.strip() for p in value.split(",") if p.strip()) + + +@click.group() +def cli() -> None: + """Org-to-org data migration over the Platform API.""" + + +@cli.command("migrate") +@click.option("--source-url", required=True, help="Base URL of the source deployment (e.g. https://us.unstract.com)") +@click.option("--source-org", required=True, help="Source organization_id (slug in the URL path)") +@click.option( + "--source-key", + envvar="UNSTRACT_SRC_PLATFORM_KEY", + required=True, + help="Source admin's Platform API key (or env UNSTRACT_SRC_PLATFORM_KEY)", +) +@click.option("--target-url", required=True, help="Base URL of the target deployment") +@click.option("--target-org", required=True, help="Target organization_id (slug in the URL path)") +@click.option( + "--target-key", + envvar="UNSTRACT_TGT_PLATFORM_KEY", + required=True, + help="Target admin's Platform API key (or env UNSTRACT_TGT_PLATFORM_KEY)", +) +@click.option("--dry-run", is_flag=True, help="Plan only — do not POST anything to target") +@click.option( + "--include", + default=None, + help="Comma-separated phase names to include (default: all)", +) +@click.option( + "--exclude", + default=None, + help="Comma-separated phase names to exclude", +) +@click.option( + "--on-name-conflict", + type=click.Choice(["adopt", "abort"]), + default="adopt", + 
show_default=True, + help="What to do when a like-named entity exists in target", +) +@click.option( + "--api-prefix", + default="api/v1", + show_default=True, + help="Backend URL prefix (matches deployment's PATH_PREFIX env)", +) +@click.option("-v", "--verbose", is_flag=True, help="Debug logging") +def migrate_cmd( + source_url: str, + source_org: str, + source_key: str, + target_url: str, + target_org: str, + target_key: str, + dry_run: bool, + include: str | None, + exclude: str | None, + on_name_conflict: str, + api_prefix: str, + verbose: bool, +) -> None: + """Migrate configured resources from one org to another.""" + _configure_logging(verbose) + + options = MigrationOptions( + dry_run=dry_run, + include=_split_csv(include), + exclude=_split_csv(exclude) or (), + on_name_conflict=on_name_conflict, + verbose=verbose, + ) + + source = OrgEndpoint( + base_url=source_url, + organization_id=source_org, + platform_key=source_key, + api_path_prefix=api_prefix, + ) + target = OrgEndpoint( + base_url=target_url, + organization_id=target_org, + platform_key=target_key, + api_path_prefix=api_prefix, + ) + + try: + report = run_migrate(source, target, options) + except MigrationError as e: + click.echo(f"Migration failed: {e}", err=True) + sys.exit(2) + + click.echo(report.render()) + if report.aborted or any(p.failed for p in report.phases): + sys.exit(1) + + +def main(argv: list[str] | None = None) -> Any: + return cli(args=argv, standalone_mode=True) + + +if __name__ == "__main__": + main() diff --git a/src/unstract/migration/client.py b/src/unstract/migration/client.py new file mode 100644 index 0000000..4df7185 --- /dev/null +++ b/src/unstract/migration/client.py @@ -0,0 +1,101 @@ +"""Thin Platform API client for the migration subpackage. + +One ``PlatformClient`` instance per ``OrgEndpoint``. Methods are entity- +scoped (``list_adapters``, ``create_adapter``, ...) so the call sites in +phases read like business logic, not HTTP plumbing. 
+ +URL shape (see ``backend/middleware/organization_middleware.py``): + {base_url}/api/v2/unstract/{org_id}// + +Auth: ``Authorization: Bearer `` (see +``backend/account_v2/custom_auth_middleware.py``). +""" + +from __future__ import annotations + +import logging +from typing import Any + +import requests + +from unstract.migration.context import OrgEndpoint +from unstract.migration.exceptions import PlatformAPIError + +logger = logging.getLogger(__name__) + +DEFAULT_TIMEOUT = 60 + + +class PlatformClient: + """HTTP client scoped to a single org via its Platform API key.""" + + def __init__(self, endpoint: OrgEndpoint, timeout: int = DEFAULT_TIMEOUT, verify: bool = True): + self.endpoint = endpoint + self.timeout = timeout + self.verify = verify + self._session = requests.Session() + self._session.headers.update( + { + "Authorization": f"Bearer {endpoint.platform_key}", + "Accept": "application/json", + } + ) + + def _url(self, path: str) -> str: + base = self.endpoint.base_url.rstrip("/") + api_prefix = self.endpoint.api_path_prefix.strip("/") + prefix = f"/{api_prefix}/unstract/{self.endpoint.organization_id}/" + return base + prefix + path.lstrip("/") + + def _request( + self, + method: str, + path: str, + *, + params: dict[str, Any] | None = None, + json: Any = None, + ) -> Any: + url = self._url(path) + # Redact secrets from logs: only entity path + method, never body. 
+ logger.debug("%s %s", method, url) + resp = self._session.request( + method, + url, + params=params, + json=json, + timeout=self.timeout, + verify=self.verify, + ) + if not 200 <= resp.status_code < 300: + raise PlatformAPIError( + f"{method} {path} returned {resp.status_code}", + status_code=resp.status_code, + body=resp.text[:2000], + ) + if resp.status_code == 204 or not resp.content: + return None + return resp.json() + + # ----- adapters ----- + + def list_adapters( + self, + *, + name: str | None = None, + adapter_type: str | None = None, + ) -> list[dict[str, Any]]: + """List adapters in this org, optionally filtered by name and/or type.""" + params: dict[str, Any] = {} + if name is not None: + params["adapter_name"] = name + if adapter_type is not None: + params["adapter_type"] = adapter_type + result = self._request("GET", "adapter/", params=params) + # DRF ModelViewSet.list returns a bare list (no pagination on this endpoint). + return result if isinstance(result, list) else result.get("results", []) + + def get_adapter(self, adapter_pk: str) -> dict[str, Any]: + return self._request("GET", f"adapter/{adapter_pk}/") + + def create_adapter(self, payload: dict[str, Any]) -> dict[str, Any]: + return self._request("POST", "adapter/", json=payload) diff --git a/src/unstract/migration/context.py b/src/unstract/migration/context.py new file mode 100644 index 0000000..8d84f1a --- /dev/null +++ b/src/unstract/migration/context.py @@ -0,0 +1,98 @@ +"""Shared state passed between migration phases. + +Three top-level types: + +- ``OrgEndpoint`` — base URL + organization_id + Platform API key for one org. +- ``MigrationOptions`` — run flags (dry-run, include/exclude, name-conflict). +- ``MigrationContext`` — bundles source/target clients, options, and the + per-run ``RemapTable``. + +``RemapTable`` lives here too because every phase touches it. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from unstract.migration.client import PlatformClient + + +@dataclass(frozen=True) +class OrgEndpoint: + """One end of a migration: where to talk to and who to talk as. + + ``organization_id`` is the slug embedded in the URL path (see + ``OrganizationMiddleware`` regex ``/api/(v1|v2)/unstract//...``). + ``platform_key`` is the bearer UUID issued by an org admin. + + ``api_path_prefix`` defaults to ``api/v1`` to match the OSS docker + compose ``PATH_PREFIX`` env. Cloud / on-prem envs that mount on a + different prefix can override (e.g. ``api/v2``). + """ + + base_url: str + organization_id: str + platform_key: str + api_path_prefix: str = "api/v1" + + +@dataclass +class MigrationOptions: + """Per-run flags for ``migrate()``.""" + + dry_run: bool = False + include: tuple[str, ...] | None = None + exclude: tuple[str, ...] = () + on_name_conflict: str = "adopt" # "adopt" | "abort" + verbose: bool = False + + def includes(self, phase_name: str) -> bool: + if self.include is not None and phase_name not in self.include: + return False + return phase_name not in self.exclude + + +class RemapTable: + """Maps source UUID -> target UUID, scoped per entity type. + + Built up in dependency order; consumed by the JSON walker before POST. + ``resolve_any`` lets the walker look up a UUID without knowing its + entity type — necessary because embedded references in JSON payloads + don't always carry an entity hint. 
+ """ + + def __init__(self) -> None: + self._table: dict[str, dict[str, str]] = {} + + def record(self, entity: str, src_uuid: str, tgt_uuid: str) -> None: + self._table.setdefault(entity, {})[src_uuid] = tgt_uuid + + def resolve(self, entity: str, src_uuid: str) -> str | None: + return self._table.get(entity, {}).get(src_uuid) + + def resolve_any(self, src_uuid: str) -> str | None: + for mapping in self._table.values(): + hit = mapping.get(src_uuid) + if hit is not None: + return hit + return None + + def snapshot(self) -> dict[str, dict[str, str]]: + """Read-only snapshot for the post-run report.""" + return {entity: dict(m) for entity, m in self._table.items()} + + +@dataclass +class MigrationContext: + """Shared state for one ``migrate()`` invocation. + + Phases hold a reference to this and call ``ctx.source`` / ``ctx.target`` + to drive HTTP, ``ctx.remap`` to record UUID mappings. + """ + + source: PlatformClient + target: PlatformClient + options: MigrationOptions + remap: RemapTable = field(default_factory=RemapTable) diff --git a/src/unstract/migration/exceptions.py b/src/unstract/migration/exceptions.py new file mode 100644 index 0000000..572d9e5 --- /dev/null +++ b/src/unstract/migration/exceptions.py @@ -0,0 +1,22 @@ +"""Exceptions raised by the migration subpackage.""" + + +class MigrationError(Exception): + """Base class for all migration errors.""" + + +class PlatformAPIError(MigrationError): + """Raised when the Platform API returns a non-2xx response we can't recover from.""" + + def __init__(self, message: str, status_code: int | None = None, body: str | None = None): + super().__init__(message) + self.status_code = status_code + self.body = body + + +class NameConflictError(MigrationError): + """Raised when ``on_name_conflict='abort'`` and the target has a like-named entity.""" + + +class DependencyMissingError(MigrationError): + """Raised when a phase references a source UUID that no prior phase has mapped.""" diff --git 
def migrate(
    source: OrgEndpoint,
    target: OrgEndpoint,
    options: MigrationOptions | None = None,
) -> MigrationReport:
    """Migrate configured resources from one org to another.

    Phases listed in ``PHASES`` run in order. A phase that is filtered out by
    ``options`` is recorded in ``report.skipped_phases``.

    Args:
        source: Endpoint of the org to read from.
        target: Endpoint of the org to write to.
        options: Run flags; defaults to ``MigrationOptions()``.

    Returns:
        A ``MigrationReport``. When a phase raises ``MigrationError`` (which
        includes ``NameConflictError`` under ``on_name_conflict='abort'``),
        the error is NOT re-raised: the report is returned with
        ``aborted=True`` and ``abort_reason`` set, and later phases do not
        run. The remap snapshot is attached in every case.

    Raises:
        Exception: Anything that is not a ``MigrationError`` (for example a
            failure while constructing the clients) propagates unchanged.
    """
    opts = options or MigrationOptions()

    # A typo in include/exclude would otherwise be silently ignored and, for
    # ``include``, turn the whole run into a no-op that still looks successful.
    known = {phase_name for phase_name, _ in PHASES}
    unknown = sorted((set(opts.include or ()) | set(opts.exclude)) - known)
    if unknown:
        logger.warning(
            "Ignoring unknown phase name(s) in include/exclude: %s",
            ", ".join(unknown),
        )

    ctx = MigrationContext(
        source=PlatformClient(source),
        target=PlatformClient(target),
        options=opts,
    )
    report = MigrationReport()

    for name, phase_cls in PHASES:
        if not opts.includes(name):
            report.skipped_phases.append(name)
            logger.info("Phase '%s' skipped (excluded)", name)
            continue
        logger.info("=== Phase: %s ===", name)
        try:
            phase_cls(ctx).run(report)
        except MigrationError as e:
            # Later phases depend on this one's remap entries, so stop here.
            report.aborted = True
            report.abort_reason = str(e)
            logger.error("Phase '%s' aborted: %s", name, e)
            break

    report.remap_snapshot = ctx.remap.snapshot()
    return report
class AdapterPhase(Phase):
    """Get-or-create every source adapter on the target.

    For each adapter listed by the source: fetch its detail, look it up on
    the target by name and type, adopt the match or POST a new adapter, and
    record the source -> target id in ``ctx.remap`` under ``"adapter"``.
    """

    name = "adapter"

    def run(self, report: MigrationReport) -> PhaseResult:
        """Migrate all source adapters and return this phase's counters.

        A failure to list the source adapters is recorded as one failure and
        ends the phase. Per-adapter failures are recorded and the loop
        continues.

        Raises:
            NameConflictError: From ``_migrate_one`` when
                ``on_name_conflict == "abort"`` and the target already has
                the adapter.
        """
        result = report.get_phase(self.name)
        try:
            src_summaries = self.ctx.source.list_adapters()
        except Exception as e:
            logger.exception("Failed to list source adapters: %s", e)
            result.failed += 1
            result.errors.append(f"list source adapters: {e}")
            return result

        logger.info("Found %d adapter(s) in source org", len(src_summaries))
        for summary in src_summaries:
            self._migrate_one(summary, result)
        return result

    def _migrate_one(self, summary: dict[str, Any], result: PhaseResult) -> None:
        """Adopt or create one adapter on the target and record its remap.

        Args:
            summary: One row of the source list response; must carry ``id``,
                ``adapter_name`` and ``adapter_type``.
            result: Counters and error list updated in place.

        Raises:
            NameConflictError: When the adapter already exists on the target
                and ``on_name_conflict == "abort"``.
        """
        name = summary["adapter_name"]
        atype = summary["adapter_type"]
        src_id = summary["id"]
        # List response omits adapter_metadata (see AdapterListSerializer);
        # fetch the detail endpoint to pick up the decrypted metadata.
        try:
            src = self.ctx.source.get_adapter(src_id)
        except Exception as e:
            logger.exception("Failed to GET source adapter %s [%s] detail: %s", name, atype, e)
            result.failed += 1
            result.errors.append(f"GET source detail {name} [{atype}]: {e}")
            return

        try:
            candidates = self.ctx.target.list_adapters(name=name, adapter_type=atype)
        except Exception as e:
            logger.exception("Failed to GET adapter %s [%s] on target: %s", name, atype, e)
            result.failed += 1
            result.errors.append(f"GET {name} [{atype}]: {e}")
            return

        # Re-check the natural key locally. If the target applies the query
        # filters loosely (or not at all), blindly taking the first row would
        # adopt an unrelated adapter and poison the remap table.
        existing = [
            row
            for row in candidates
            if row.get("adapter_name") == name and row.get("adapter_type") == atype
        ]
        if len(existing) > 1:
            logger.warning(
                "%d adapters named '%s' [%s] on target; adopting the first",
                len(existing), name, atype,
            )

        if existing:
            tgt = existing[0]
            if self.ctx.options.on_name_conflict == "abort":
                raise NameConflictError(
                    f"adapter '{name}' [{atype}] already exists in target as {tgt['id']}"
                )
            result.adopted += 1
            logger.info(
                "adopted adapter '%s' [%s] src=%s -> tgt=%s",
                name, atype, src_id, tgt["id"],
            )
        elif self.ctx.options.dry_run:
            result.skipped += 1
            logger.info("[dry-run] would create adapter '%s' [%s] src=%s", name, atype, src_id)
            return
        else:
            payload = {k: src[k] for k in ADAPTER_POST_FIELDS if k in src and src[k] is not None}
            try:
                tgt = self.ctx.target.create_adapter(payload)
            except Exception as e:
                logger.exception("Failed to create adapter %s [%s]: %s", name, atype, e)
                result.failed += 1
                result.errors.append(f"create {name} [{atype}]: {e}")
                return
            if not isinstance(tgt, dict) or not tgt.get("id"):
                # Without an id there is nothing to remap to; report it
                # instead of letting a KeyError/TypeError escape the phase.
                logger.error("Create response for adapter '%s' [%s] carried no id", name, atype)
                result.failed += 1
                result.errors.append(f"create {name} [{atype}]: response carried no id")
                return
            result.created += 1
            logger.info(
                "created adapter '%s' [%s] src=%s -> tgt=%s",
                name, atype, src_id, tgt["id"],
            )

        self.ctx.remap.record("adapter", src_id, tgt["id"])
logging.getLogger(__name__) + + +class Phase(ABC): + """Abstract phase. One subclass per entity type.""" + + name: str = "" + + def __init__(self, ctx: MigrationContext): + self.ctx = ctx + + @abstractmethod + def run(self, report: MigrationReport) -> PhaseResult: + """Migrate all entities of this phase's type. Idempotent across runs.""" + raise NotImplementedError diff --git a/src/unstract/migration/report.py b/src/unstract/migration/report.py new file mode 100644 index 0000000..e07ea75 --- /dev/null +++ b/src/unstract/migration/report.py @@ -0,0 +1,111 @@ +"""Structured report produced by ``migrate()``. + +Tracks per-phase counts (created / adopted / skipped / failed) and a final +remap snapshot. Renders to a rich-formatted table when ``rich`` is +available; falls back to plain text otherwise. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class PhaseResult: + name: str + created: int = 0 + adopted: int = 0 + skipped: int = 0 + failed: int = 0 + errors: list[str] = field(default_factory=list) + + +@dataclass +class MigrationReport: + phases: list[PhaseResult] = field(default_factory=list) + skipped_phases: list[str] = field(default_factory=list) + remap_snapshot: dict[str, dict[str, str]] = field(default_factory=dict) + aborted: bool = False + abort_reason: str | None = None + + def get_phase(self, name: str) -> PhaseResult: + for p in self.phases: + if p.name == name: + return p + result = PhaseResult(name=name) + self.phases.append(result) + return result + + def render(self) -> str: + """Render as a rich table when available, otherwise plain text.""" + try: + from io import StringIO + + from rich.console import Console + from rich.table import Table + except ImportError: + return self._render_plain() + + buf = StringIO() + console = Console(file=buf, force_terminal=False, width=100) + table = Table(title="Migration Report") + for col in ("Phase", "Created", "Adopted", "Skipped", 
"Failed"): + table.add_column(col, justify="right" if col != "Phase" else "left") + for p in self.phases: + table.add_row(p.name, str(p.created), str(p.adopted), str(p.skipped), str(p.failed)) + console.print(table) + if self.skipped_phases: + console.print(f"[dim]Skipped phases:[/dim] {', '.join(self.skipped_phases)}") + if self.remap_snapshot: + remap = Table(title="Source -> Target UUID Map") + remap.add_column("Entity") + remap.add_column("Source UUID") + remap.add_column("Target UUID") + for entity, mapping in self.remap_snapshot.items(): + for src, tgt in mapping.items(): + remap.add_row(entity, src, tgt) + console.print(remap) + if self.aborted: + console.print(f"[red]ABORTED:[/red] {self.abort_reason}") + return buf.getvalue() + + def _render_plain(self) -> str: + lines = ["Migration Report", "=" * 60] + header = f"{'Phase':<24}{'Created':>10}{'Adopted':>10}{'Skipped':>10}{'Failed':>10}" + lines.append(header) + for p in self.phases: + lines.append( + f"{p.name:<24}{p.created:>10}{p.adopted:>10}{p.skipped:>10}{p.failed:>10}" + ) + if self.skipped_phases: + lines.append(f"Skipped phases: {', '.join(self.skipped_phases)}") + if self.remap_snapshot: + lines.append("") + lines.append("Source -> Target UUID Map") + lines.append("-" * 60) + for entity, mapping in self.remap_snapshot.items(): + for src, tgt in mapping.items(): + lines.append(f" {entity:<12} {src} -> {tgt}") + if self.aborted: + lines.append(f"ABORTED: {self.abort_reason}") + return "\n".join(lines) + + def as_dict(self) -> dict[str, Any]: + return { + "phases": [ + { + "name": p.name, + "created": p.created, + "adopted": p.adopted, + "skipped": p.skipped, + "failed": p.failed, + "errors": list(p.errors), + } + for p in self.phases + ], + "skipped_phases": list(self.skipped_phases), + "remap_snapshot": self.remap_snapshot, + "aborted": self.aborted, + "abort_reason": self.abort_reason, + } diff --git a/src/unstract/migration/walker.py b/src/unstract/migration/walker.py new file mode 100644 index 
UUID_RE = re.compile(
    r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
    re.IGNORECASE,
)


def remap_uuids(obj: Any, remap: "RemapTable") -> Any:
    """Return a copy of a JSON-shaped value with known source UUIDs rewritten.

    Dicts and lists are rebuilt recursively (dict keys are left as they
    are). A string is replaced only when it is a UUID in canonical hyphenated
    form AND ``remap.resolve_any`` knows a target for it; every other value,
    including unknown UUIDs, passes through untouched.

    ``UUID_RE`` accepts any letter case, so the lookup is tried with the
    string as written and then lower-cased. Without the second attempt an
    upper-case reference to a mapped UUID would silently keep pointing at the
    source entity.

    Args:
        obj: Any JSON-compatible value.
        remap: Object exposing ``resolve_any(src_uuid) -> str | None``.

    Returns:
        The rewritten value; containers are new objects.
    """
    if isinstance(obj, dict):
        return {k: remap_uuids(v, remap) for k, v in obj.items()}
    if isinstance(obj, list):
        return [remap_uuids(v, remap) for v in obj]
    # fullmatch: ``$`` alone would also accept a trailing newline.
    if isinstance(obj, str) and UUID_RE.fullmatch(obj):
        return remap.resolve_any(obj) or remap.resolve_any(obj.lower()) or obj
    return obj
Verifies: +- happy path: source has N adapters, target gets N POSTs, all remapped +- idempotency: re-run with target already populated → zero POSTs, all adopted +- dry-run: zero POSTs, all skipped +- on_name_conflict='abort' raises on existing +""" + +from __future__ import annotations + +import pytest + +from unstract.migration.context import ( + MigrationContext, + MigrationOptions, + RemapTable, +) +from unstract.migration.exceptions import NameConflictError +from unstract.migration.phases.adapter import AdapterPhase +from unstract.migration.report import MigrationReport + + +class FakeClient: + """Minimal in-memory stand-in for ``PlatformClient``.""" + + def __init__(self, adapters: list[dict] | None = None): + # Stored as a list of dicts; mutated by create_adapter. + self.adapters: list[dict] = list(adapters or []) + self.posts: list[dict] = [] + self._next_id = 1 + + def list_adapters(self, *, name=None, adapter_type=None): + result = self.adapters + if name is not None: + result = [a for a in result if a["adapter_name"] == name] + if adapter_type is not None: + result = [a for a in result if a["adapter_type"] == adapter_type] + # Mimic AdapterListSerializer — strip adapter_metadata from list output. 
+ return [{k: v for k, v in a.items() if k != "adapter_metadata"} for a in result] + + def get_adapter(self, adapter_pk): + for a in self.adapters: + if a["id"] == adapter_pk: + return a + raise KeyError(adapter_pk) + + def create_adapter(self, payload): + new = dict(payload) + new["id"] = f"tgt-{self._next_id:08d}-0000-0000-0000-000000000000" + self._next_id += 1 + self.adapters.append(new) + self.posts.append(new) + return new + + +def _src_adapter(id_, name, atype="LLM"): + return { + "id": id_, + "adapter_id": "openai-llm-v2", + "adapter_name": name, + "adapter_type": atype, + "adapter_metadata": {"api_key": "sk-secret", "model": "gpt-4"}, + "description": f"{name} desc", + } + + +def _ctx(source: FakeClient, target: FakeClient, **opt_overrides): + ctx = MigrationContext( + source=source, + target=target, + options=MigrationOptions(**opt_overrides), + remap=RemapTable(), + ) + return ctx + + +def test_happy_path_creates_all_and_records_remap(): + src = FakeClient( + [ + _src_adapter("src-a", "OpenAI Prod"), + _src_adapter("src-b", "Mistral Stg", atype="EMBEDDING"), + ] + ) + tgt = FakeClient() + ctx = _ctx(src, tgt) + report = MigrationReport() + + result = AdapterPhase(ctx).run(report) + + assert result.created == 2 + assert result.adopted == 0 + assert result.failed == 0 + assert len(tgt.posts) == 2 + assert ctx.remap.resolve("adapter", "src-a") == tgt.posts[0]["id"] + assert ctx.remap.resolve("adapter", "src-b") == tgt.posts[1]["id"] + + +def test_idempotency_zero_creates_on_rerun(): + src_adapters = [_src_adapter("src-a", "OpenAI Prod")] + src = FakeClient(src_adapters) + # Target pre-populated with the same name+type — simulates a prior run. 
+ tgt = FakeClient( + [ + { + "id": "preexisting", + "adapter_id": "openai-llm-v2", + "adapter_name": "OpenAI Prod", + "adapter_type": "LLM", + "adapter_metadata": {}, + } + ] + ) + ctx = _ctx(src, tgt, on_name_conflict="adopt") + report = MigrationReport() + + result = AdapterPhase(ctx).run(report) + + assert result.created == 0 + assert result.adopted == 1 + assert tgt.posts == [] # no new POSTs + assert ctx.remap.resolve("adapter", "src-a") == "preexisting" + + +def test_dry_run_makes_no_posts(): + src = FakeClient([_src_adapter("src-a", "OpenAI Prod")]) + tgt = FakeClient() + ctx = _ctx(src, tgt, dry_run=True) + report = MigrationReport() + + result = AdapterPhase(ctx).run(report) + + assert result.skipped == 1 + assert result.created == 0 + assert tgt.posts == [] + + +def test_abort_on_name_conflict_raises(): + src = FakeClient([_src_adapter("src-a", "OpenAI Prod")]) + tgt = FakeClient( + [ + { + "id": "preexisting", + "adapter_id": "openai-llm-v2", + "adapter_name": "OpenAI Prod", + "adapter_type": "LLM", + "adapter_metadata": {}, + } + ] + ) + ctx = _ctx(src, tgt, on_name_conflict="abort") + report = MigrationReport() + + with pytest.raises(NameConflictError): + AdapterPhase(ctx).run(report) diff --git a/tests/migration/test_remap_table.py b/tests/migration/test_remap_table.py new file mode 100644 index 0000000..3045326 --- /dev/null +++ b/tests/migration/test_remap_table.py @@ -0,0 +1,36 @@ +"""Tests for ``RemapTable``.""" + +from unstract.migration.context import RemapTable + + +def test_record_and_resolve_per_entity(): + t = RemapTable() + t.record("adapter", "src-1", "tgt-1") + t.record("adapter", "src-2", "tgt-2") + t.record("connector", "src-1", "tgt-99") + + assert t.resolve("adapter", "src-1") == "tgt-1" + assert t.resolve("adapter", "src-2") == "tgt-2" + assert t.resolve("connector", "src-1") == "tgt-99" + + +def test_resolve_missing_returns_none(): + t = RemapTable() + assert t.resolve("adapter", "nope") is None + assert t.resolve_any("nope") is 
None + + +def test_resolve_any_searches_across_entities(): + t = RemapTable() + t.record("adapter", "src-a", "tgt-a") + t.record("workflow", "src-w", "tgt-w") + assert t.resolve_any("src-a") == "tgt-a" + assert t.resolve_any("src-w") == "tgt-w" + + +def test_snapshot_is_independent_copy(): + t = RemapTable() + t.record("adapter", "src-1", "tgt-1") + snap = t.snapshot() + t.record("adapter", "src-2", "tgt-2") + assert "src-2" not in snap["adapter"] diff --git a/tests/migration/test_walker.py b/tests/migration/test_walker.py new file mode 100644 index 0000000..44107bc --- /dev/null +++ b/tests/migration/test_walker.py @@ -0,0 +1,54 @@ +"""Tests for ``remap_uuids``.""" + +from unstract.migration.context import RemapTable +from unstract.migration.walker import remap_uuids + +SRC_A = "11111111-1111-1111-1111-111111111111" +TGT_A = "22222222-2222-2222-2222-222222222222" +SRC_B = "33333333-3333-3333-3333-333333333333" +TGT_B = "44444444-4444-4444-4444-444444444444" +UNRELATED = "55555555-5555-5555-5555-555555555555" + + +def _populated_remap(): + t = RemapTable() + t.record("adapter", SRC_A, TGT_A) + t.record("workflow", SRC_B, TGT_B) + return t + + +def test_remaps_mapped_uuid_string(): + assert remap_uuids(SRC_A, _populated_remap()) == TGT_A + + +def test_leaves_unmapped_uuid_untouched(): + assert remap_uuids(UNRELATED, _populated_remap()) == UNRELATED + + +def test_leaves_non_uuid_string_alone(): + assert remap_uuids("hello-world", _populated_remap()) == "hello-world" + + +def test_remaps_inside_nested_dict_and_list(): + payload = { + "id": SRC_A, + "config": { + "refs": [SRC_B, "not-a-uuid", UNRELATED], + "nested": {"adapter_id": SRC_A}, + }, + "count": 42, + } + result = remap_uuids(payload, _populated_remap()) + assert result == { + "id": TGT_A, + "config": { + "refs": [TGT_B, "not-a-uuid", UNRELATED], + "nested": {"adapter_id": TGT_A}, + }, + "count": 42, + } + + +def test_handles_non_string_scalars(): + payload = {"a": 1, "b": True, "c": None, "d": 3.14} + assert 
remap_uuids(payload, _populated_remap()) == payload diff --git a/uv.lock b/uv.lock index 25478df..57e8a81 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.11" [[package]] @@ -102,6 +102,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" }, ] +[[package]] +name = "click" +version = "8.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9b/98/518d8e5081007684232226f475082b30087d0f585e8457db087298259f49/click-8.4.1.tar.gz", hash = "sha256:918b5633eddf6b41c32d4f454bf0de810065c74e3f7dbf8ee5452f8be88d3e96", size = 353007, upload-time = "2026-05-22T04:08:37.769Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/0d/67e5b4109ea4a837e80daa87c2c696711955e40449a97e8926672534def2/click-8.4.1-py3-none-any.whl", hash = "sha256:482be17c6991b8c19c5429a1e995d9b0efdbb63172824c41f99965dc0ade8ec2", size = 116639, upload-time = "2026-05-22T04:08:35.26Z" }, +] + [[package]] name = "colorama" version = "0.4.6" @@ -282,6 +294,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, ] +[[package]] +name = "markdown-it-py" +version = "4.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/06/ff/7841249c247aa650a76b9ee4bbaeae59370dc8bfd2f6c01f3630c35eb134/markdown_it_py-4.2.0.tar.gz", hash = "sha256:04a21681d6fbb623de53f6f364d352309d4094dd4194040a10fd51833e418d49", size = 82454, upload-time = "2026-05-07T12:08:28.36Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/81/4da04ced5a082363ecfa159c010d200ecbd959ae410c10c0264a38cac0f5/markdown_it_py-4.2.0-py3-none-any.whl", hash = "sha256:9f7ebbcd14fe59494226453aed97c1070d83f8d24b6fc3a3bcf9a38092641c4a", size = 91687, upload-time = "2026-05-07T12:08:27.182Z" }, +] + [[package]] name = "mbstrdecoder" version = "1.1.4" @@ -294,6 +318,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/30/ac/5ce64a1d4cce00390beab88622a290420401f1cabf05caf2fc0995157c21/mbstrdecoder-1.1.4-py3-none-any.whl", hash = "sha256:03dae4ec50ec0d2ff4743e63fdbd5e0022815857494d35224b60775d3d934a8c", size = 7933, upload-time = "2025-01-18T10:07:29.562Z" }, ] +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, +] + [[package]] name = "mypy" version = "1.10.1" @@ -593,6 +626,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = 
"2025-08-18T20:46:00.542Z" }, ] +[[package]] +name = "rich" +version = "15.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c0/8f/0722ca900cc807c13a6a0c696dacf35430f72e0ec571c4275d2371fca3e9/rich-15.0.0.tar.gz", hash = "sha256:edd07a4824c6b40189fb7ac9bc4c52536e9780fbbfbddf6f1e2502c31b068c36", size = 230680, upload-time = "2026-04-12T08:24:00.75Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl", hash = "sha256:33bd4ef74232fb73fe9279a257718407f169c09b78a87ad3d296f548e27de0bb", size = 310654, upload-time = "2026-04-12T08:24:02.83Z" }, +] + [[package]] name = "ruff" version = "0.15.1" @@ -757,6 +803,12 @@ dependencies = [ { name = "tenacity" }, ] +[package.optional-dependencies] +migration = [ + { name = "click" }, + { name = "rich" }, +] + [package.dev-dependencies] dev = [ { name = "docutils" }, @@ -789,9 +841,12 @@ test = [ [package.metadata] requires-dist = [ + { name = "click", marker = "extra == 'migration'", specifier = ">=8.1" }, { name = "requests", specifier = ">=2.32.3" }, + { name = "rich", marker = "extra == 'migration'", specifier = ">=13.7" }, { name = "tenacity", specifier = ">=8.2.0" }, ] +provides-extras = ["migration"] [package.metadata.requires-dev] dev = [ From 17aa92349f7b5ecef2af29255f220c82a3d1e0a9 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Sun, 24 May 2026 01:18:56 +0530 Subject: [PATCH 02/25] chore(migration): drop cross-repo refs from code comments Per the project's code-comments guidance: comments explain WHY in generic terms, not 'see file X in repo Y'. Path/class references rot when files move and don't help a future reader of this package who may not have the backend repo open. Behavior unchanged. 13 unit tests + integration smoke still green. 
--- src/unstract/migration/client.py | 11 ++++------- src/unstract/migration/context.py | 10 +++------- src/unstract/migration/phases/adapter.py | 9 ++++----- 3 files changed, 11 insertions(+), 19 deletions(-) diff --git a/src/unstract/migration/client.py b/src/unstract/migration/client.py index 4df7185..c6ded7b 100644 --- a/src/unstract/migration/client.py +++ b/src/unstract/migration/client.py @@ -1,14 +1,11 @@ """Thin Platform API client for the migration subpackage. One ``PlatformClient`` instance per ``OrgEndpoint``. Methods are entity- -scoped (``list_adapters``, ``create_adapter``, ...) so the call sites in -phases read like business logic, not HTTP plumbing. +scoped (``list_adapters``, ``create_adapter``, ...) so call sites in phases +read like business logic, not HTTP plumbing. -URL shape (see ``backend/middleware/organization_middleware.py``): - {base_url}/api/v2/unstract/{org_id}/<resource>/ - -Auth: ``Authorization: Bearer <platform_key>`` (see -``backend/account_v2/custom_auth_middleware.py``). +URL shape: ``{base_url}/{api_path_prefix}/unstract/{organization_id}/<resource>/`` +Auth: ``Authorization: Bearer <platform_key>``. """ from __future__ import annotations diff --git a/src/unstract/migration/context.py b/src/unstract/migration/context.py index 8d84f1a..98d7ce7 100644 --- a/src/unstract/migration/context.py +++ b/src/unstract/migration/context.py @@ -23,13 +23,9 @@ class OrgEndpoint: """One end of a migration: where to talk to and who to talk as. - ``organization_id`` is the slug embedded in the URL path (see - ``OrganizationMiddleware`` regex ``/api/(v1|v2)/unstract/<org_id>/...``). - ``platform_key`` is the bearer UUID issued by an org admin. - - ``api_path_prefix`` defaults to ``api/v1`` to match the OSS docker - compose ``PATH_PREFIX`` env. Cloud / on-prem envs that mount on a - different prefix can override (e.g. ``api/v2``). + ``organization_id`` is the slug embedded in the URL path; the bearer + Platform API key must belong to this org.
``api_path_prefix`` matches + the deployment's URL prefix (defaults to ``api/v1``). """ base_url: str diff --git a/src/unstract/migration/phases/adapter.py b/src/unstract/migration/phases/adapter.py index 5544107..3885cb6 100644 --- a/src/unstract/migration/phases/adapter.py +++ b/src/unstract/migration/phases/adapter.py @@ -4,9 +4,9 @@ against target, POST create if missing, record source->target UUID in the remap table for downstream phases. -``AdapterInstanceManager.for_user(service_account)`` returns only -non-frictionless adapters, so frictionless onboarding adapters are -intentionally excluded from migration. +Frictionless onboarding adapters are excluded — the backend's +service-account queryset already filters them out, so migration never +sees them. """ from __future__ import annotations @@ -55,8 +55,7 @@ def _migrate_one(self, summary: dict[str, Any], result: PhaseResult) -> None: name = summary["adapter_name"] atype = summary["adapter_type"] src_id = summary["id"] - # List response omits adapter_metadata (see AdapterListSerializer); - # fetch the detail endpoint to pick up the decrypted metadata. + # List response omits adapter_metadata; fetch detail to pick it up. try: src = self.ctx.source.get_adapter(src_id) except Exception as e: From f28c3c37173620a40cab535100727ab6e99c7bba Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Sun, 24 May 2026 01:51:57 +0530 Subject: [PATCH 03/25] feat(migration): connector + tag phases MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two independent leaf phases following the AdapterPhase template: list -> per-id GET -> POST/adopt by name, record source->target remap. - ConnectorPhase: skips Unstract Cloud Storage rows (catalog id is redacted on the wire — target re-provisions per-org). OAuth-backed connectors land without refresh tokens; operator re-authorises on target. - TagPhase: simplest entity — name + description, no encryption, no list-vs-detail divergence. 
Orchestrator PHASES order extended; 22 unit tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/unstract/migration/client.py | 37 +++++ src/unstract/migration/orchestrator.py | 9 +- src/unstract/migration/phases/__init__.py | 4 +- src/unstract/migration/phases/connector.py | 119 +++++++++++++++ src/unstract/migration/phases/tag.py | 79 ++++++++++ tests/migration/test_connector_phase.py | 160 +++++++++++++++++++++ tests/migration/test_tag_phase.py | 104 ++++++++++++++ 7 files changed, 508 insertions(+), 4 deletions(-) create mode 100644 src/unstract/migration/phases/connector.py create mode 100644 src/unstract/migration/phases/tag.py create mode 100644 tests/migration/test_connector_phase.py create mode 100644 tests/migration/test_tag_phase.py diff --git a/src/unstract/migration/client.py b/src/unstract/migration/client.py index c6ded7b..007b318 100644 --- a/src/unstract/migration/client.py +++ b/src/unstract/migration/client.py @@ -96,3 +96,40 @@ def get_adapter(self, adapter_pk: str) -> dict[str, Any]: def create_adapter(self, payload: dict[str, Any]) -> dict[str, Any]: return self._request("POST", "adapter/", json=payload) + + # ----- connectors ----- + + def list_connectors( + self, + *, + name: str | None = None, + connector_type: str | None = None, + ) -> list[dict[str, Any]]: + """List connectors in this org, optionally filtered by name and/or type.""" + params: dict[str, Any] = {} + if name is not None: + params["connector_name"] = name + if connector_type is not None: + params["connector_type"] = connector_type + result = self._request("GET", "connector/", params=params) + return result if isinstance(result, list) else result.get("results", []) + + def get_connector(self, connector_pk: str) -> dict[str, Any]: + return self._request("GET", f"connector/{connector_pk}/") + + def create_connector(self, payload: dict[str, Any]) -> dict[str, Any]: + return self._request("POST", "connector/", json=payload) + + # ----- tags ----- + + def 
list_tags(self, *, name: str | None = None) -> list[dict[str, Any]]: + """List tags in this org, optionally filtered by exact name.""" + params: dict[str, Any] = {} + if name is not None: + params["name"] = name + result = self._request("GET", "tags/", params=params) + # Tags endpoint uses pagination — accept either bare list or paginated envelope. + return result if isinstance(result, list) else result.get("results", []) + + def create_tag(self, payload: dict[str, Any]) -> dict[str, Any]: + return self._request("POST", "tags/", json=payload) diff --git a/src/unstract/migration/orchestrator.py b/src/unstract/migration/orchestrator.py index f6832fd..f3e78b6 100644 --- a/src/unstract/migration/orchestrator.py +++ b/src/unstract/migration/orchestrator.py @@ -16,17 +16,20 @@ from unstract.migration.client import PlatformClient from unstract.migration.context import MigrationContext, MigrationOptions, OrgEndpoint from unstract.migration.exceptions import MigrationError -from unstract.migration.phases import AdapterPhase +from unstract.migration.phases import AdapterPhase, ConnectorPhase, TagPhase from unstract.migration.phases.base import Phase from unstract.migration.report import MigrationReport logger = logging.getLogger(__name__) # Strict dependency order. Each entry: (phase_name, phase_class). -# v1 vertical slice ships AdapterPhase only; remaining phases land in -# follow-up commits per the plan. +# Adapter, connector, tag are independent leaf phases. Downstream phases +# (custom_tool, workflow, tool_instance, workflow_endpoint) land later +# and consume the remap entries these produce. 
PHASES: list[tuple[str, type[Phase]]] = [ ("adapter", AdapterPhase), + ("connector", ConnectorPhase), + ("tag", TagPhase), ] diff --git a/src/unstract/migration/phases/__init__.py b/src/unstract/migration/phases/__init__.py index 102f39a..ff6cac5 100644 --- a/src/unstract/migration/phases/__init__.py +++ b/src/unstract/migration/phases/__init__.py @@ -9,5 +9,7 @@ from unstract.migration.phases.adapter import AdapterPhase from unstract.migration.phases.base import Phase +from unstract.migration.phases.connector import ConnectorPhase +from unstract.migration.phases.tag import TagPhase -__all__ = ["AdapterPhase", "Phase"] +__all__ = ["AdapterPhase", "ConnectorPhase", "Phase", "TagPhase"] diff --git a/src/unstract/migration/phases/connector.py b/src/unstract/migration/phases/connector.py new file mode 100644 index 0000000..9794fc9 --- /dev/null +++ b/src/unstract/migration/phases/connector.py @@ -0,0 +1,119 @@ +"""Migrate connectors from source org to target org. + +Same list -> per-id GET -> POST/adopt pattern as AdapterPhase. Two +connector-specific wrinkles: + +1. **Auto-provisioned UCS connectors are skipped.** The Unstract Cloud + Storage connector has its ``connector_metadata`` redacted to ``{}`` + on the wire, so we cannot reliably reconstruct it on the target. + The target org is expected to have its own UCS row already; downstream + phases (workflow endpoints) must remap by ``connector_id`` lookup + rather than relying on the remap table here. + +2. **OAuth ``connector_auth`` is stripped from responses.** Tokens are + stored in a sibling ``ConnectorAuth`` row that the public API never + exposes, so OAuth-backed connectors land on the target without + refresh tokens. Operator must re-authorise on target. 
+""" + +from __future__ import annotations + +import logging +from typing import Any + +from unstract.migration.exceptions import NameConflictError +from unstract.migration.phases.base import Phase +from unstract.migration.report import MigrationReport, PhaseResult + +logger = logging.getLogger(__name__) + +UCS_CONNECTOR_ID = "pcs|b8cd25cd-4452-4d54-bd5e-e7d71459b702" + +CONNECTOR_POST_FIELDS = ( + "connector_id", + "connector_name", + "connector_metadata", + "connector_version", + "connector_type", + "shared_to_org", +) + + +class ConnectorPhase(Phase): + name = "connector" + + def run(self, report: MigrationReport) -> PhaseResult: + result = report.get_phase(self.name) + try: + src_summaries = self.ctx.source.list_connectors() + except Exception as e: + logger.exception("Failed to list source connectors: %s", e) + result.failed += 1 + result.errors.append(f"list source connectors: {e}") + return result + + logger.info("Found %d connector(s) in source org", len(src_summaries)) + for summary in src_summaries: + self._migrate_one(summary, result) + return result + + def _migrate_one(self, summary: dict[str, Any], result: PhaseResult) -> None: + name = summary["connector_name"] + src_id = summary["id"] + catalog_id = summary.get("connector_id") + + if catalog_id == UCS_CONNECTOR_ID: + logger.info( + "skipping UCS connector '%s' (src=%s) — auto-provisioned per-org", + name, src_id, + ) + result.skipped += 1 + return + + try: + src = self.ctx.source.get_connector(src_id) + except Exception as e: + logger.exception("Failed to GET source connector %s detail: %s", name, e) + result.failed += 1 + result.errors.append(f"GET source detail {name}: {e}") + return + + try: + existing = self.ctx.target.list_connectors(name=name) + except Exception as e: + logger.exception("Failed to GET connector %s on target: %s", name, e) + result.failed += 1 + result.errors.append(f"GET {name}: {e}") + return + + if existing: + tgt = existing[0] + if self.ctx.options.on_name_conflict == 
"abort": + raise NameConflictError( + f"connector '{name}' already exists in target as {tgt['id']}" + ) + result.adopted += 1 + logger.info( + "adopted connector '%s' src=%s -> tgt=%s", + name, src_id, tgt["id"], + ) + elif self.ctx.options.dry_run: + result.skipped += 1 + logger.info("[dry-run] would create connector '%s' src=%s", name, src_id) + return + else: + payload = {k: src[k] for k in CONNECTOR_POST_FIELDS if k in src and src[k] is not None} + try: + tgt = self.ctx.target.create_connector(payload) + except Exception as e: + logger.exception("Failed to create connector %s: %s", name, e) + result.failed += 1 + result.errors.append(f"create {name}: {e}") + return + result.created += 1 + logger.info( + "created connector '%s' src=%s -> tgt=%s", + name, src_id, tgt["id"], + ) + + self.ctx.remap.record("connector", src_id, tgt["id"]) diff --git a/src/unstract/migration/phases/tag.py b/src/unstract/migration/phases/tag.py new file mode 100644 index 0000000..1d6a81b --- /dev/null +++ b/src/unstract/migration/phases/tag.py @@ -0,0 +1,79 @@ +"""Migrate tags from source org to target org. + +Tags are flat (``name`` + ``description``) with a per-org uniqueness +constraint on ``name``. No metadata, no encryption, no list-vs-detail +divergence — the simplest entity in the migration set. + +List endpoint paginates; ``PlatformClient.list_tags`` already unwraps +the envelope. 
+""" + +from __future__ import annotations + +import logging +from typing import Any + +from unstract.migration.exceptions import NameConflictError +from unstract.migration.phases.base import Phase +from unstract.migration.report import MigrationReport, PhaseResult + +logger = logging.getLogger(__name__) + +TAG_POST_FIELDS = ("name", "description") + + +class TagPhase(Phase): + name = "tag" + + def run(self, report: MigrationReport) -> PhaseResult: + result = report.get_phase(self.name) + try: + src_tags = self.ctx.source.list_tags() + except Exception as e: + logger.exception("Failed to list source tags: %s", e) + result.failed += 1 + result.errors.append(f"list source tags: {e}") + return result + + logger.info("Found %d tag(s) in source org", len(src_tags)) + for src in src_tags: + self._migrate_one(src, result) + return result + + def _migrate_one(self, src: dict[str, Any], result: PhaseResult) -> None: + name = src["name"] + src_id = src["id"] + + try: + existing = self.ctx.target.list_tags(name=name) + except Exception as e: + logger.exception("Failed to GET tag %s on target: %s", name, e) + result.failed += 1 + result.errors.append(f"GET {name}: {e}") + return + + if existing: + tgt = existing[0] + if self.ctx.options.on_name_conflict == "abort": + raise NameConflictError( + f"tag '{name}' already exists in target as {tgt['id']}" + ) + result.adopted += 1 + logger.info("adopted tag '%s' src=%s -> tgt=%s", name, src_id, tgt["id"]) + elif self.ctx.options.dry_run: + result.skipped += 1 + logger.info("[dry-run] would create tag '%s' src=%s", name, src_id) + return + else: + payload = {k: src[k] for k in TAG_POST_FIELDS if k in src and src[k] is not None} + try: + tgt = self.ctx.target.create_tag(payload) + except Exception as e: + logger.exception("Failed to create tag %s: %s", name, e) + result.failed += 1 + result.errors.append(f"create {name}: {e}") + return + result.created += 1 + logger.info("created tag '%s' src=%s -> tgt=%s", name, src_id, tgt["id"]) + + 
self.ctx.remap.record("tag", src_id, tgt["id"]) diff --git a/tests/migration/test_connector_phase.py b/tests/migration/test_connector_phase.py new file mode 100644 index 0000000..e4341b3 --- /dev/null +++ b/tests/migration/test_connector_phase.py @@ -0,0 +1,160 @@ +"""Tests for ``ConnectorPhase``. + +Mirrors the adapter phase suite — happy path, idempotency, dry-run, +abort — plus connector-specific behavior: UCS auto-provisioned rows are +skipped without consulting the target. +""" + +from __future__ import annotations + +import pytest + +from unstract.migration.context import ( + MigrationContext, + MigrationOptions, + RemapTable, +) +from unstract.migration.exceptions import NameConflictError +from unstract.migration.phases.connector import ( + UCS_CONNECTOR_ID, + ConnectorPhase, +) +from unstract.migration.report import MigrationReport + + +class FakeClient: + def __init__(self, connectors: list[dict] | None = None): + self.connectors: list[dict] = list(connectors or []) + self.posts: list[dict] = [] + self._next_id = 1 + + def list_connectors(self, *, name=None, connector_type=None): + result = self.connectors + if name is not None: + result = [c for c in result if c["connector_name"] == name] + if connector_type is not None: + result = [c for c in result if c.get("connector_type") == connector_type] + return list(result) + + def get_connector(self, connector_pk): + for c in self.connectors: + if c["id"] == connector_pk: + return c + raise KeyError(connector_pk) + + def create_connector(self, payload): + new = dict(payload) + new["id"] = f"tgt-{self._next_id:08d}-0000-0000-0000-000000000000" + self._next_id += 1 + self.connectors.append(new) + self.posts.append(new) + return new + + +def _src(id_, name, catalog_id="postgres|abc", ctype="INPUT"): + return { + "id": id_, + "connector_id": catalog_id, + "connector_name": name, + "connector_type": ctype, + "connector_version": "1.0", + "connector_metadata": {"host": "db.example.com", "password": "secret"}, + 
"shared_to_org": False, + } + + +def _ctx(source, target, **opt_overrides): + return MigrationContext( + source=source, + target=target, + options=MigrationOptions(**opt_overrides), + remap=RemapTable(), + ) + + +def test_happy_path_creates_all_and_records_remap(): + src = FakeClient([_src("src-a", "Prod PG"), _src("src-b", "Stg S3", "s3|xyz")]) + tgt = FakeClient() + ctx = _ctx(src, tgt) + report = MigrationReport() + + result = ConnectorPhase(ctx).run(report) + + assert result.created == 2 + assert result.adopted == 0 + assert result.skipped == 0 + assert len(tgt.posts) == 2 + assert ctx.remap.resolve("connector", "src-a") == tgt.posts[0]["id"] + assert ctx.remap.resolve("connector", "src-b") == tgt.posts[1]["id"] + + +def test_ucs_connector_skipped_without_target_lookup(): + """UCS rows must be skipped pre-flight — no POST, no remap entry.""" + src = FakeClient([_src("src-ucs", "User Storage", catalog_id=UCS_CONNECTOR_ID)]) + tgt = FakeClient() + ctx = _ctx(src, tgt) + report = MigrationReport() + + result = ConnectorPhase(ctx).run(report) + + assert result.skipped == 1 + assert result.created == 0 + assert tgt.posts == [] + assert ctx.remap.resolve("connector", "src-ucs") is None + + +def test_idempotency_zero_creates_on_rerun(): + src = FakeClient([_src("src-a", "Prod PG")]) + tgt = FakeClient( + [ + { + "id": "preexisting", + "connector_id": "postgres|abc", + "connector_name": "Prod PG", + "connector_type": "INPUT", + "connector_metadata": {}, + } + ] + ) + ctx = _ctx(src, tgt, on_name_conflict="adopt") + report = MigrationReport() + + result = ConnectorPhase(ctx).run(report) + + assert result.created == 0 + assert result.adopted == 1 + assert tgt.posts == [] + assert ctx.remap.resolve("connector", "src-a") == "preexisting" + + +def test_dry_run_makes_no_posts(): + src = FakeClient([_src("src-a", "Prod PG")]) + tgt = FakeClient() + ctx = _ctx(src, tgt, dry_run=True) + report = MigrationReport() + + result = ConnectorPhase(ctx).run(report) + + assert 
result.skipped == 1 + assert result.created == 0 + assert tgt.posts == [] + + +def test_abort_on_name_conflict_raises(): + src = FakeClient([_src("src-a", "Prod PG")]) + tgt = FakeClient( + [ + { + "id": "preexisting", + "connector_id": "postgres|abc", + "connector_name": "Prod PG", + "connector_type": "INPUT", + "connector_metadata": {}, + } + ] + ) + ctx = _ctx(src, tgt, on_name_conflict="abort") + report = MigrationReport() + + with pytest.raises(NameConflictError): + ConnectorPhase(ctx).run(report) diff --git a/tests/migration/test_tag_phase.py b/tests/migration/test_tag_phase.py new file mode 100644 index 0000000..7628d5b --- /dev/null +++ b/tests/migration/test_tag_phase.py @@ -0,0 +1,104 @@ +"""Tests for ``TagPhase``. + +Tag is the simplest entity — no encryption, no list-vs-detail divergence. +Suite covers happy / idempotency / dry-run / abort. +""" + +from __future__ import annotations + +import pytest + +from unstract.migration.context import ( + MigrationContext, + MigrationOptions, + RemapTable, +) +from unstract.migration.exceptions import NameConflictError +from unstract.migration.phases.tag import TagPhase +from unstract.migration.report import MigrationReport + + +class FakeClient: + def __init__(self, tags: list[dict] | None = None): + self.tags: list[dict] = list(tags or []) + self.posts: list[dict] = [] + self._next_id = 1 + + def list_tags(self, *, name=None): + result = self.tags + if name is not None: + result = [t for t in result if t["name"] == name] + return list(result) + + def create_tag(self, payload): + new = dict(payload) + new["id"] = f"tgt-{self._next_id:08d}-0000-0000-0000-000000000000" + self._next_id += 1 + self.tags.append(new) + self.posts.append(new) + return new + + +def _src(id_, name): + return {"id": id_, "name": name, "description": f"{name} desc"} + + +def _ctx(source, target, **opt_overrides): + return MigrationContext( + source=source, + target=target, + options=MigrationOptions(**opt_overrides), + remap=RemapTable(), + 
) + + +def test_happy_path_creates_all_and_records_remap(): + src = FakeClient([_src("src-a", "billing"), _src("src-b", "finance")]) + tgt = FakeClient() + ctx = _ctx(src, tgt) + report = MigrationReport() + + result = TagPhase(ctx).run(report) + + assert result.created == 2 + assert result.adopted == 0 + assert len(tgt.posts) == 2 + assert ctx.remap.resolve("tag", "src-a") == tgt.posts[0]["id"] + assert ctx.remap.resolve("tag", "src-b") == tgt.posts[1]["id"] + + +def test_idempotency_zero_creates_on_rerun(): + src = FakeClient([_src("src-a", "billing")]) + tgt = FakeClient([{"id": "preexisting", "name": "billing", "description": "x"}]) + ctx = _ctx(src, tgt, on_name_conflict="adopt") + report = MigrationReport() + + result = TagPhase(ctx).run(report) + + assert result.created == 0 + assert result.adopted == 1 + assert tgt.posts == [] + assert ctx.remap.resolve("tag", "src-a") == "preexisting" + + +def test_dry_run_makes_no_posts(): + src = FakeClient([_src("src-a", "billing")]) + tgt = FakeClient() + ctx = _ctx(src, tgt, dry_run=True) + report = MigrationReport() + + result = TagPhase(ctx).run(report) + + assert result.skipped == 1 + assert result.created == 0 + assert tgt.posts == [] + + +def test_abort_on_name_conflict_raises(): + src = FakeClient([_src("src-a", "billing")]) + tgt = FakeClient([{"id": "preexisting", "name": "billing", "description": "x"}]) + ctx = _ctx(src, tgt, on_name_conflict="abort") + report = MigrationReport() + + with pytest.raises(NameConflictError): + TagPhase(ctx).run(report) From 41022d50df5174716e6472c10dcf75ab0c583eeb Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Sun, 24 May 2026 02:05:58 +0530 Subject: [PATCH 04/25] refactor(migration): drive POST payloads from DRF OPTIONS schema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Drops the hardcoded *_POST_FIELDS tuples and the UCS catalog-ID constant. 
Backend serializer is now the single source of truth for which fields the SDK posts: - PlatformClient.get_post_schema(entity_path) issues OPTIONS once per path, caches the writable-field set (fields that DRF SimpleMetadata marks read_only in actions.POST are filtered out). - Each phase fetches its schema in run(); builds the POST body by intersecting the source GET payload with the schema. - ConnectorPhase: replace UCS_CONNECTOR_ID hardcode with an empty-metadata signal — backend redacts metadata to {} for auto-provisioned rows, so a falsy connector_metadata on the wire is unmigratable. Future redactions (any catalog) are covered automatically. Test FakeClients gain a get_post_schema() that mirrors the writable subset; UCS test renamed to test_redacted_metadata_connector_skipped. 22/22 unit tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/unstract/migration/client.py | 23 ++++++++++ src/unstract/migration/phases/adapter.py | 21 ++++----- src/unstract/migration/phases/connector.py | 52 +++++++++++----------- src/unstract/migration/phases/tag.py | 11 ++++- tests/migration/test_adapter_phase.py | 8 ++++ tests/migration/test_connector_phase.py | 29 +++++++++--- tests/migration/test_tag_phase.py | 5 +++ 7 files changed, 102 insertions(+), 47 deletions(-) diff --git a/src/unstract/migration/client.py b/src/unstract/migration/client.py index 007b318..af62e79 100644 --- a/src/unstract/migration/client.py +++ b/src/unstract/migration/client.py @@ -37,6 +37,9 @@ def __init__(self, endpoint: OrgEndpoint, timeout: int = DEFAULT_TIMEOUT, verify "Accept": "application/json", } ) + # Cache the OPTIONS-derived writable-field set per entity path. + # Backend serializer is the single source of truth; we read it once.
+ self._post_schema_cache: dict[str, frozenset[str]] = {} def _url(self, path: str) -> str: base = self.endpoint.base_url.rstrip("/") @@ -73,6 +76,26 @@ def _request( return None return resp.json() + def get_post_schema(self, entity_path: str) -> frozenset[str]: + """Return the set of fields the backend's POST serializer accepts. + + Reads it from a DRF ``OPTIONS`` response (``actions.POST``) once + per path and caches the result. DRF ``SimpleMetadata`` lists + ``read_only`` fields in ``actions.POST`` too, so they are filtered + out here to leave only the writable subset. + """ + cached = self._post_schema_cache.get(entity_path) + if cached is not None: + return cached + body = self._request("OPTIONS", entity_path) + actions = (body or {}).get("actions") or {} + post_block = actions.get("POST") or {} + writable = frozenset( + name for name, meta in post_block.items() if not meta.get("read_only") + ) + self._post_schema_cache[entity_path] = writable + return writable + # ----- adapters ----- def list_adapters( diff --git a/src/unstract/migration/phases/adapter.py b/src/unstract/migration/phases/adapter.py index 3885cb6..5738cf3 100644 --- a/src/unstract/migration/phases/adapter.py +++ b/src/unstract/migration/phases/adapter.py @@ -20,17 +20,7 @@ logger = logging.getLogger(__name__) -# Fields copied verbatim from source GET into target POST. Everything else -# (id, created_by, deprecation flags, icon, etc.) is either auto-set by the -# target backend or derived — carrying it would either be ignored or cause -# validation noise.
-ADAPTER_POST_FIELDS = ( - "adapter_id", - "adapter_name", - "adapter_type", - "adapter_metadata", - "description", -) +ADAPTER_PATH = "adapter/" class AdapterPhase(Phase): @@ -38,6 +28,13 @@ class AdapterPhase(Phase): def run(self, report: MigrationReport) -> PhaseResult: result = report.get_phase(self.name) + try: + self._writable = self.ctx.target.get_post_schema(ADAPTER_PATH) + except Exception as e: + logger.exception("Failed to fetch target POST schema for adapter: %s", e) + result.failed += 1 + result.errors.append(f"OPTIONS adapter: {e}") + return result try: src_summaries = self.ctx.source.list_adapters() except Exception as e: @@ -88,7 +85,7 @@ def _migrate_one(self, summary: dict[str, Any], result: PhaseResult) -> None: logger.info("[dry-run] would create adapter '%s' [%s] src=%s", name, atype, src_id) return else: - payload = {k: src[k] for k in ADAPTER_POST_FIELDS if k in src and src[k] is not None} + payload = {k: src[k] for k in self._writable if k in src and src[k] is not None} try: tgt = self.ctx.target.create_adapter(payload) except Exception as e: diff --git a/src/unstract/migration/phases/connector.py b/src/unstract/migration/phases/connector.py index 9794fc9..608e70a 100644 --- a/src/unstract/migration/phases/connector.py +++ b/src/unstract/migration/phases/connector.py @@ -3,12 +3,13 @@ Same list -> per-id GET -> POST/adopt pattern as AdapterPhase. Two connector-specific wrinkles: -1. **Auto-provisioned UCS connectors are skipped.** The Unstract Cloud - Storage connector has its ``connector_metadata`` redacted to ``{}`` - on the wire, so we cannot reliably reconstruct it on the target. - The target org is expected to have its own UCS row already; downstream - phases (workflow endpoints) must remap by ``connector_id`` lookup - rather than relying on the remap table here. +1. **Connectors with redacted metadata are skipped.** The backend + serializer strips ``connector_metadata`` for auto-provisioned rows + (e.g. 
Unstract Cloud Storage), so the SDK cannot reconstruct them + on the target. We detect this by inspecting the source GET response: + a falsy ``connector_metadata`` means the operator must rely on the + target's own provisioning (or re-create the row manually) — the + remap table records no entry for these. 2. **OAuth ``connector_auth`` is stripped from responses.** Tokens are stored in a sibling ``ConnectorAuth`` row that the public API never @@ -27,16 +28,7 @@ logger = logging.getLogger(__name__) -UCS_CONNECTOR_ID = "pcs|b8cd25cd-4452-4d54-bd5e-e7d71459b702" - -CONNECTOR_POST_FIELDS = ( - "connector_id", - "connector_name", - "connector_metadata", - "connector_version", - "connector_type", - "shared_to_org", -) +CONNECTOR_PATH = "connector/" class ConnectorPhase(Phase): @@ -44,6 +36,13 @@ class ConnectorPhase(Phase): def run(self, report: MigrationReport) -> PhaseResult: result = report.get_phase(self.name) + try: + self._writable = self.ctx.target.get_post_schema(CONNECTOR_PATH) + except Exception as e: + logger.exception("Failed to fetch target POST schema for connector: %s", e) + result.failed += 1 + result.errors.append(f"OPTIONS connector: {e}") + return result try: src_summaries = self.ctx.source.list_connectors() except Exception as e: @@ -60,15 +59,6 @@ def run(self, report: MigrationReport) -> PhaseResult: def _migrate_one(self, summary: dict[str, Any], result: PhaseResult) -> None: name = summary["connector_name"] src_id = summary["id"] - catalog_id = summary.get("connector_id") - - if catalog_id == UCS_CONNECTOR_ID: - logger.info( - "skipping UCS connector '%s' (src=%s) — auto-provisioned per-org", - name, src_id, - ) - result.skipped += 1 - return try: src = self.ctx.source.get_connector(src_id) @@ -78,6 +68,16 @@ def _migrate_one(self, summary: dict[str, Any], result: PhaseResult) -> None: result.errors.append(f"GET source detail {name}: {e}") return + # Empty metadata means the backend redacted it (auto-provisioned rows + # like Unstract Cloud 
Storage). We cannot reconstruct it on target. + if not src.get("connector_metadata"): + logger.info( + "skipping connector '%s' (src=%s, catalog=%s) — source returned no metadata", + name, src_id, src.get("connector_id"), + ) + result.skipped += 1 + return + try: existing = self.ctx.target.list_connectors(name=name) except Exception as e: @@ -102,7 +102,7 @@ def _migrate_one(self, summary: dict[str, Any], result: PhaseResult) -> None: logger.info("[dry-run] would create connector '%s' src=%s", name, src_id) return else: - payload = {k: src[k] for k in CONNECTOR_POST_FIELDS if k in src and src[k] is not None} + payload = {k: src[k] for k in self._writable if k in src and src[k] is not None} try: tgt = self.ctx.target.create_connector(payload) except Exception as e: diff --git a/src/unstract/migration/phases/tag.py b/src/unstract/migration/phases/tag.py index 1d6a81b..c32692a 100644 --- a/src/unstract/migration/phases/tag.py +++ b/src/unstract/migration/phases/tag.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) -TAG_POST_FIELDS = ("name", "description") +TAG_PATH = "tags/" class TagPhase(Phase): @@ -27,6 +27,13 @@ class TagPhase(Phase): def run(self, report: MigrationReport) -> PhaseResult: result = report.get_phase(self.name) + try: + self._writable = self.ctx.target.get_post_schema(TAG_PATH) + except Exception as e: + logger.exception("Failed to fetch target POST schema for tag: %s", e) + result.failed += 1 + result.errors.append(f"OPTIONS tag: {e}") + return result try: src_tags = self.ctx.source.list_tags() except Exception as e: @@ -65,7 +72,7 @@ def _migrate_one(self, src: dict[str, Any], result: PhaseResult) -> None: logger.info("[dry-run] would create tag '%s' src=%s", name, src_id) return else: - payload = {k: src[k] for k in TAG_POST_FIELDS if k in src and src[k] is not None} + payload = {k: src[k] for k in self._writable if k in src and src[k] is not None} try: tgt = self.ctx.target.create_tag(payload) except Exception as e: diff --git 
a/tests/migration/test_adapter_phase.py b/tests/migration/test_adapter_phase.py index d6827ff..4530a86 100644 --- a/tests/migration/test_adapter_phase.py +++ b/tests/migration/test_adapter_phase.py @@ -24,12 +24,20 @@ class FakeClient: """Minimal in-memory stand-in for ``PlatformClient``.""" + # Mirrors DRF OPTIONS actions.POST writable fields for adapter. + POST_SCHEMA = frozenset( + {"adapter_id", "adapter_name", "adapter_type", "adapter_metadata", "description"} + ) + def __init__(self, adapters: list[dict] | None = None): # Stored as a list of dicts; mutated by create_adapter. self.adapters: list[dict] = list(adapters or []) self.posts: list[dict] = [] self._next_id = 1 + def get_post_schema(self, entity_path): + return self.POST_SCHEMA + def list_adapters(self, *, name=None, adapter_type=None): result = self.adapters if name is not None: diff --git a/tests/migration/test_connector_phase.py b/tests/migration/test_connector_phase.py index e4341b3..a72a098 100644 --- a/tests/migration/test_connector_phase.py +++ b/tests/migration/test_connector_phase.py @@ -15,19 +15,31 @@ RemapTable, ) from unstract.migration.exceptions import NameConflictError -from unstract.migration.phases.connector import ( - UCS_CONNECTOR_ID, - ConnectorPhase, -) +from unstract.migration.phases.connector import ConnectorPhase from unstract.migration.report import MigrationReport class FakeClient: + POST_SCHEMA = frozenset( + { + "connector_id", + "connector_name", + "connector_metadata", + "connector_version", + "connector_mode", + "connector_type", + "shared_to_org", + } + ) + def __init__(self, connectors: list[dict] | None = None): self.connectors: list[dict] = list(connectors or []) self.posts: list[dict] = [] self._next_id = 1 + def get_post_schema(self, entity_path): + return self.POST_SCHEMA + def list_connectors(self, *, name=None, connector_type=None): result = self.connectors if name is not None: @@ -88,9 +100,12 @@ def test_happy_path_creates_all_and_records_remap(): assert 
ctx.remap.resolve("connector", "src-b") == tgt.posts[1]["id"] -def test_ucs_connector_skipped_without_target_lookup(): - """UCS rows must be skipped pre-flight — no POST, no remap entry.""" - src = FakeClient([_src("src-ucs", "User Storage", catalog_id=UCS_CONNECTOR_ID)]) +def test_redacted_metadata_connector_skipped(): + """Source returning empty metadata (redacted by backend) is unmigratable — + skipped with no POST and no remap entry.""" + redacted = _src("src-ucs", "User Storage") + redacted["connector_metadata"] = {} # backend redaction signal + src = FakeClient([redacted]) tgt = FakeClient() ctx = _ctx(src, tgt) report = MigrationReport() diff --git a/tests/migration/test_tag_phase.py b/tests/migration/test_tag_phase.py index 7628d5b..c9fc3e9 100644 --- a/tests/migration/test_tag_phase.py +++ b/tests/migration/test_tag_phase.py @@ -19,11 +19,16 @@ class FakeClient: + POST_SCHEMA = frozenset({"name", "description"}) + def __init__(self, tags: list[dict] | None = None): self.tags: list[dict] = list(tags or []) self.posts: list[dict] = [] self._next_id = 1 + def get_post_schema(self, entity_path): + return self.POST_SCHEMA + def list_tags(self, *, name=None): result = self.tags if name is not None: From fcd7048d65fc1c0431171c123b6c4ea6038100c9 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Sun, 24 May 2026 02:11:06 +0530 Subject: [PATCH 05/25] refactor(migration): centralise POST payload construction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add phases.base.build_post_payload(src, writable) used by every phase: - Subtracts SERVER_MANAGED from the OPTIONS-derived writable set. DRF exposes id/organization/created_by/modified_by/shared_users/timestamps as writable on ModelSerializer, but the view's perform_create overrides them server-side — posting them is noise (and a source-org value for organization/created_by would mismatch the target). - Skips empty strings as well as None. 
DRF treats '' on a required field as blank and 400s (hit on connector_version=''). Local smoke: 8/8 connectors + 2/2 tags migrated, idempotent re-run adopts all 10. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/unstract/migration/phases/adapter.py | 4 +-- src/unstract/migration/phases/base.py | 31 ++++++++++++++++++++++ src/unstract/migration/phases/connector.py | 4 +-- src/unstract/migration/phases/tag.py | 4 +-- 4 files changed, 37 insertions(+), 6 deletions(-) diff --git a/src/unstract/migration/phases/adapter.py b/src/unstract/migration/phases/adapter.py index 5738cf3..711ef8c 100644 --- a/src/unstract/migration/phases/adapter.py +++ b/src/unstract/migration/phases/adapter.py @@ -15,7 +15,7 @@ from typing import Any from unstract.migration.exceptions import NameConflictError -from unstract.migration.phases.base import Phase +from unstract.migration.phases.base import Phase, build_post_payload from unstract.migration.report import MigrationReport, PhaseResult logger = logging.getLogger(__name__) @@ -85,7 +85,7 @@ def _migrate_one(self, summary: dict[str, Any], result: PhaseResult) -> None: logger.info("[dry-run] would create adapter '%s' [%s] src=%s", name, atype, src_id) return else: - payload = {k: src[k] for k in self._writable if k in src and src[k] is not None} + payload = build_post_payload(src, self._writable) try: tgt = self.ctx.target.create_adapter(payload) except Exception as e: diff --git a/src/unstract/migration/phases/base.py b/src/unstract/migration/phases/base.py index c22f587..4b3e0ec 100644 --- a/src/unstract/migration/phases/base.py +++ b/src/unstract/migration/phases/base.py @@ -4,12 +4,43 @@ import logging from abc import ABC, abstractmethod +from typing import Any from unstract.migration.context import MigrationContext from unstract.migration.report import MigrationReport, PhaseResult logger = logging.getLogger(__name__) +# DRF OPTIONS reports any ModelSerializer FK/M2M as writable, but the +# backend's perform_create overrides these 
server-side. Posting them is +# either noise (silently overwritten) or a 400 (when a source-org value +# doesn't validate against the target org). Strip them universally — +# the phase OPTIONS schema covers the entity-specific writable subset. +SERVER_MANAGED: frozenset[str] = frozenset( + { + "id", + "organization", + "created_by", + "created_by_email", + "modified_by", + "modified_by_email", + "created_at", + "modified_at", + "shared_users", + } +) + + +def build_post_payload( + src: dict[str, Any], writable: frozenset[str] +) -> dict[str, Any]: + """Project ``src`` onto the writable schema, dropping server-managed + fields, ``None`` values, and empty strings (which DRF treats as blank + and rejects on required fields). + """ + keys = writable - SERVER_MANAGED + return {k: src[k] for k in keys if k in src and src[k] not in (None, "")} + class Phase(ABC): """Abstract phase. One subclass per entity type.""" diff --git a/src/unstract/migration/phases/connector.py b/src/unstract/migration/phases/connector.py index 608e70a..2e2ea33 100644 --- a/src/unstract/migration/phases/connector.py +++ b/src/unstract/migration/phases/connector.py @@ -23,7 +23,7 @@ from typing import Any from unstract.migration.exceptions import NameConflictError -from unstract.migration.phases.base import Phase +from unstract.migration.phases.base import Phase, build_post_payload from unstract.migration.report import MigrationReport, PhaseResult logger = logging.getLogger(__name__) @@ -102,7 +102,7 @@ def _migrate_one(self, summary: dict[str, Any], result: PhaseResult) -> None: logger.info("[dry-run] would create connector '%s' src=%s", name, src_id) return else: - payload = {k: src[k] for k in self._writable if k in src and src[k] is not None} + payload = build_post_payload(src, self._writable) try: tgt = self.ctx.target.create_connector(payload) except Exception as e: diff --git a/src/unstract/migration/phases/tag.py b/src/unstract/migration/phases/tag.py index c32692a..9381a26 100644 --- 
a/src/unstract/migration/phases/tag.py +++ b/src/unstract/migration/phases/tag.py @@ -14,7 +14,7 @@ from typing import Any from unstract.migration.exceptions import NameConflictError -from unstract.migration.phases.base import Phase +from unstract.migration.phases.base import Phase, build_post_payload from unstract.migration.report import MigrationReport, PhaseResult logger = logging.getLogger(__name__) @@ -72,7 +72,7 @@ def _migrate_one(self, src: dict[str, Any], result: PhaseResult) -> None: logger.info("[dry-run] would create tag '%s' src=%s", name, src_id) return else: - payload = {k: src[k] for k in self._writable if k in src and src[k] is not None} + payload = build_post_payload(src, self._writable) try: tgt = self.ctx.target.create_tag(payload) except Exception as e: From 8863f29d94df72f7e9ae87f3d84f50395219dffa Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Sun, 24 May 2026 02:36:35 +0530 Subject: [PATCH 06/25] feat(migration): CustomTool composite phase Migrates prompt-studio projects with their ProfileManager + ToolStudioPrompt children, then republishes PromptStudioRegistry via the backend's export-tool action (avoids carrying tool_metadata across orgs). Walker-remaps adapter UUIDs into profile FKs and across embedded JSON fields. Fresh-tool path deletes the backend's auto-default profile before recreating from source. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- src/unstract/migration/client.py | 70 ++++ src/unstract/migration/orchestrator.py | 8 +- src/unstract/migration/phases/__init__.py | 3 +- src/unstract/migration/phases/custom_tool.py | 316 +++++++++++++++++ tests/migration/test_custom_tool_phase.py | 342 +++++++++++++++++++ 5 files changed, 737 insertions(+), 2 deletions(-) create mode 100644 src/unstract/migration/phases/custom_tool.py create mode 100644 tests/migration/test_custom_tool_phase.py diff --git a/src/unstract/migration/client.py b/src/unstract/migration/client.py index af62e79..5b4b414 100644 --- a/src/unstract/migration/client.py +++ b/src/unstract/migration/client.py @@ -156,3 +156,73 @@ def list_tags(self, *, name: str | None = None) -> list[dict[str, Any]]: def create_tag(self, payload: dict[str, Any]) -> dict[str, Any]: return self._request("POST", "tags/", json=payload) + + # ----- custom tools (prompt studio) ----- + + def list_custom_tools(self) -> list[dict[str, Any]]: + """List all prompt-studio projects in this org. No name filter.""" + result = self._request("GET", "prompt-studio/") + return result if isinstance(result, list) else result.get("results", []) + + def get_custom_tool(self, tool_id: str) -> dict[str, Any]: + """Tool detail; response includes embedded ``prompts`` + ``default_profile``.""" + return self._request("GET", f"prompt-studio/{tool_id}/") + + def create_custom_tool(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create a custom tool. Backend also auto-creates one default ProfileManager.""" + return self._request("POST", "prompt-studio/", json=payload) + + def export_custom_tool(self, tool_id: str, *, force: bool = True) -> Any: + """Republish ``PromptStudioRegistry`` from the tool's current target state. + + Used after profile+prompt reconciliation so the registry row is + rebuilt without the SDK ever carrying ``tool_metadata`` across orgs. 
+ """ + return self._request( + "POST", + f"prompt-studio/export/{tool_id}", + json={ + "is_shared_with_org": False, + "user_id": [], + "force_export": force, + }, + ) + + # ----- profile managers ----- + + def list_profiles(self, tool_id: str) -> list[dict[str, Any]]: + """List ProfileManager rows for a tool via the per-tool list action.""" + result = self._request( + "GET", f"prompt-studio/prompt-studio-profile/{tool_id}/" + ) + return result if isinstance(result, list) else result.get("results", []) + + def create_profile(self, tool_id: str, payload: dict[str, Any]) -> dict[str, Any]: + """POST to ``prompt-studio/profilemanager/{tool_id}`` (no trailing slash).""" + return self._request( + "POST", f"prompt-studio/profilemanager/{tool_id}", json=payload + ) + + def delete_profile(self, profile_id: str) -> None: + self._request("DELETE", f"profile-manager/{profile_id}/") + + def set_default_profile(self, tool_id: str, profile_id: str) -> Any: + """Mark a single profile as default for this tool (zeros the rest).""" + return self._request( + "PATCH", + f"prompt-studio/prompt-studio-profile/{tool_id}/", + json={"default_profile": profile_id}, + ) + + # ----- prompts ----- + + def list_prompts(self, *, tool_id: str) -> list[dict[str, Any]]: + """List prompts filtered by tool_id (FilterHelper-backed).""" + result = self._request("GET", "prompt/", params={"tool_id": tool_id}) + return result if isinstance(result, list) else result.get("results", []) + + def create_prompt(self, tool_id: str, payload: dict[str, Any]) -> dict[str, Any]: + """POST to ``prompt-studio/prompt-studio-prompt/{tool_id}/`` (create_prompt action).""" + return self._request( + "POST", f"prompt-studio/prompt-studio-prompt/{tool_id}/", json=payload + ) diff --git a/src/unstract/migration/orchestrator.py b/src/unstract/migration/orchestrator.py index f3e78b6..59e506b 100644 --- a/src/unstract/migration/orchestrator.py +++ b/src/unstract/migration/orchestrator.py @@ -16,7 +16,12 @@ from 
unstract.migration.client import PlatformClient from unstract.migration.context import MigrationContext, MigrationOptions, OrgEndpoint from unstract.migration.exceptions import MigrationError -from unstract.migration.phases import AdapterPhase, ConnectorPhase, TagPhase +from unstract.migration.phases import ( + AdapterPhase, + ConnectorPhase, + CustomToolPhase, + TagPhase, +) from unstract.migration.phases.base import Phase from unstract.migration.report import MigrationReport @@ -30,6 +35,7 @@ ("adapter", AdapterPhase), ("connector", ConnectorPhase), ("tag", TagPhase), + ("custom_tool", CustomToolPhase), ] diff --git a/src/unstract/migration/phases/__init__.py b/src/unstract/migration/phases/__init__.py index ff6cac5..3462a05 100644 --- a/src/unstract/migration/phases/__init__.py +++ b/src/unstract/migration/phases/__init__.py @@ -10,6 +10,7 @@ from unstract.migration.phases.adapter import AdapterPhase from unstract.migration.phases.base import Phase from unstract.migration.phases.connector import ConnectorPhase +from unstract.migration.phases.custom_tool import CustomToolPhase from unstract.migration.phases.tag import TagPhase -__all__ = ["AdapterPhase", "ConnectorPhase", "Phase", "TagPhase"] +__all__ = ["AdapterPhase", "ConnectorPhase", "CustomToolPhase", "Phase", "TagPhase"] diff --git a/src/unstract/migration/phases/custom_tool.py b/src/unstract/migration/phases/custom_tool.py new file mode 100644 index 0000000..d32a855 --- /dev/null +++ b/src/unstract/migration/phases/custom_tool.py @@ -0,0 +1,316 @@ +"""Migrate prompt-studio projects (CustomTool) and their children. + +Composite phase: a single project carries ``ProfileManager`` rows (LLM +triad config) and ``ToolStudioPrompt`` rows (the actual prompts). All +three must land together for the project to be functional on target, so +they live in one phase rather than three sibling phases. + +Within a project, the create order is: + + 1. 
CustomTool — POST creates the project and auto-creates one default + ProfileManager on target. + 2. ProfileManagers — on a freshly-created tool we delete the auto-default + first so the source's profiles land cleanly. On an adopted tool we + reconcile by ``profile_name`` (per-tool unique). + 3. ToolStudioPrompts — reconcile by ``prompt_key`` (per-tool unique). + 4. Republish PromptStudioRegistry via the ``export-tool`` action so the + registry row is rebuilt server-side from the now-correct child state. + Avoids the SDK carrying ``tool_metadata`` JSON across orgs. + +Walker remapping: adapter UUIDs embedded in the tool's adapter FKs +(``monitor_llm``, ``challenge_llm``, ``summarize_llm_adapter``), in the +profile's adapter FKs (``llm``, ``embedding_model``, ``vector_store``, +``x2text``), and in the prompt's ``profile_manager`` + ``tool_id`` FKs +are remapped before POST using the running ``RemapTable``. + +The ProfileManager GET response expands adapter FKs into nested adapter +dicts (per the backend serializer's ``to_representation``); we flatten +them back to UUIDs before walker pass. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from unstract.migration.exceptions import NameConflictError +from unstract.migration.phases.base import SERVER_MANAGED, Phase, build_post_payload +from unstract.migration.report import MigrationReport, PhaseResult +from unstract.migration.walker import remap_uuids + +logger = logging.getLogger(__name__) + +TOOL_PATH = "prompt-studio/" + +# Per-action endpoints on PromptStudioCoreView don't surface their own +# DRF metadata (OPTIONS returns the parent CustomToolSerializer schema). +# Hardcode the model-derived writable subset for the children and let the +# integration test catch backend drift. 
+PROFILE_WRITABLE: frozenset[str] = frozenset( + { + "profile_name", + "vector_store", + "embedding_model", + "llm", + "x2text", + "chunk_size", + "chunk_overlap", + "reindex", + "retrieval_strategy", + "similarity_top_k", + "section", + "prompt_studio_tool", + "is_default", + "is_summarize_llm", + } +) + +PROMPT_WRITABLE: frozenset[str] = frozenset( + { + "prompt_key", + "enforce_type", + "prompt", + "tool_id", + "sequence_number", + "prompt_type", + "profile_manager", + "output", + "assert_prompt", + "assertion_failure_prompt", + "required", + "is_assert", + "active", + "output_metadata", + "postprocessing_webhook_url", + "evaluate", + "eval_quality_faithfulness", + "eval_quality_correctness", + "eval_quality_relevance", + "eval_security_pii", + "eval_guidance_toxicity", + "eval_guidance_completeness", + } +) + +_PROFILE_ADAPTER_KEYS = ("llm", "embedding_model", "vector_store", "x2text") + + +def _flatten_profile_adapters(profile: dict[str, Any]) -> dict[str, Any]: + """ProfileManagerSerializer.to_representation expands FK adapters into + nested dicts; for write paths we need flat UUIDs back. 
+ """ + out = dict(profile) + for key in _PROFILE_ADAPTER_KEYS: + val = out.get(key) + if isinstance(val, dict) and "id" in val: + out[key] = val["id"] + return out + + +class CustomToolPhase(Phase): + name = "custom_tool" + + def run(self, report: MigrationReport) -> PhaseResult: + result = report.get_phase(self.name) + try: + self._tool_writable = self.ctx.target.get_post_schema(TOOL_PATH) + except Exception as e: + logger.exception("Failed to fetch target POST schema for prompt-studio: %s", e) + result.failed += 1 + result.errors.append(f"OPTIONS prompt-studio: {e}") + return result + + try: + src_tools = self.ctx.source.list_custom_tools() + except Exception as e: + logger.exception("Failed to list source custom tools: %s", e) + result.failed += 1 + result.errors.append(f"list source custom tools: {e}") + return result + + logger.info("Found %d custom tool(s) in source org", len(src_tools)) + for summary in src_tools: + self._migrate_one(summary, result) + return result + + def _migrate_one(self, summary: dict[str, Any], result: PhaseResult) -> None: + tool_name = summary["tool_name"] + src_tool_id = summary["tool_id"] + + try: + src_tool = self.ctx.source.get_custom_tool(src_tool_id) + except Exception as e: + logger.exception("Failed to GET source tool %s: %s", tool_name, e) + result.failed += 1 + result.errors.append(f"GET source tool {tool_name}: {e}") + return + + tgt_tool, fresh = self._get_or_create_tool(src_tool, result) + if tgt_tool is None: + return + + tgt_tool_id = tgt_tool["tool_id"] + self.ctx.remap.record("custom_tool", src_tool_id, tgt_tool_id) + + if self.ctx.options.dry_run: + logger.info( + "[dry-run] would reconcile profiles+prompts for tool '%s' src=%s", + tool_name, src_tool_id, + ) + return + + try: + src_profiles = self.ctx.source.list_profiles(src_tool_id) + except Exception as e: + logger.exception("Failed to list source profiles for %s: %s", tool_name, e) + result.failed += 1 + result.errors.append(f"list src profiles {tool_name}: 
{e}") + return + + try: + self._reconcile_profiles(src_profiles, tgt_tool_id, fresh) + except Exception as e: + logger.exception("Profile reconcile failed for tool %s: %s", tool_name, e) + result.failed += 1 + result.errors.append(f"profiles {tool_name}: {e}") + return + + try: + src_prompts = src_tool.get("prompts") or [] + self._reconcile_prompts(src_prompts, tgt_tool_id) + except Exception as e: + logger.exception("Prompt reconcile failed for tool %s: %s", tool_name, e) + result.failed += 1 + result.errors.append(f"prompts {tool_name}: {e}") + return + + try: + self.ctx.target.export_custom_tool(tgt_tool_id) + logger.info("republished registry for tool '%s' tgt=%s", tool_name, tgt_tool_id) + except Exception as e: + logger.exception("Registry republish failed for tool %s: %s", tool_name, e) + result.failed += 1 + result.errors.append(f"export {tool_name}: {e}") + + def _get_or_create_tool( + self, src_tool: dict[str, Any], result: PhaseResult + ) -> tuple[dict[str, Any] | None, bool]: + tool_name = src_tool["tool_name"] + src_tool_id = src_tool["tool_id"] + + try: + target_tools = self.ctx.target.list_custom_tools() + except Exception as e: + logger.exception("Failed to list target tools: %s", e) + result.failed += 1 + result.errors.append(f"list target tools: {e}") + return None, False + + match = next((t for t in target_tools if t["tool_name"] == tool_name), None) + if match is not None: + if self.ctx.options.on_name_conflict == "abort": + raise NameConflictError( + f"tool '{tool_name}' already exists in target as {match['tool_id']}" + ) + result.adopted += 1 + logger.info( + "adopted tool '%s' src=%s -> tgt=%s", + tool_name, src_tool_id, match["tool_id"], + ) + return match, False + + if self.ctx.options.dry_run: + result.skipped += 1 + logger.info("[dry-run] would create tool '%s' src=%s", tool_name, src_tool_id) + return None, True + + remapped = remap_uuids(src_tool, self.ctx.remap) + payload = build_post_payload(remapped, self._tool_writable) + try: + tgt 
= self.ctx.target.create_custom_tool(payload) + except Exception as e: + logger.exception("Failed to create tool %s: %s", tool_name, e) + result.failed += 1 + result.errors.append(f"create tool {tool_name}: {e}") + return None, True + result.created += 1 + logger.info( + "created tool '%s' src=%s -> tgt=%s", + tool_name, src_tool_id, tgt["tool_id"], + ) + return tgt, True + + def _reconcile_profiles( + self, + src_profiles: list[dict[str, Any]], + tgt_tool_id: str, + fresh: bool, + ) -> None: + if fresh: + for p in self.ctx.target.list_profiles(tgt_tool_id): + self.ctx.target.delete_profile(p["profile_id"]) + logger.debug("deleted auto-default profile %s", p["profile_id"]) + + src_default_id: str | None = None + for src_profile in src_profiles: + src_pid = src_profile["profile_id"] + if src_profile.get("is_default"): + src_default_id = src_pid + + target_profiles_by_name = { + p["profile_name"]: p + for p in self.ctx.target.list_profiles(tgt_tool_id) + } + existing = target_profiles_by_name.get(src_profile["profile_name"]) + + if existing is not None: + tgt_pid = existing["profile_id"] + logger.info( + "adopted profile '%s' src=%s -> tgt=%s", + src_profile["profile_name"], src_pid, tgt_pid, + ) + else: + flat = _flatten_profile_adapters(src_profile) + remapped = remap_uuids(flat, self.ctx.remap) + remapped["prompt_studio_tool"] = tgt_tool_id + payload = build_post_payload(remapped, PROFILE_WRITABLE) + tgt = self.ctx.target.create_profile(tgt_tool_id, payload) + tgt_pid = tgt["profile_id"] + logger.info( + "created profile '%s' src=%s -> tgt=%s", + src_profile["profile_name"], src_pid, tgt_pid, + ) + self.ctx.remap.record("profile_manager", src_pid, tgt_pid) + + if src_default_id: + tgt_default = self.ctx.remap.resolve("profile_manager", src_default_id) + if tgt_default: + self.ctx.target.set_default_profile(tgt_tool_id, tgt_default) + + def _reconcile_prompts( + self, src_prompts: list[dict[str, Any]], tgt_tool_id: str + ) -> None: + existing_prompts = 
self.ctx.target.list_prompts(tool_id=tgt_tool_id) + by_key = {p["prompt_key"]: p for p in existing_prompts} + + for src_prompt in src_prompts: + src_prompt_id = src_prompt["prompt_id"] + key = src_prompt["prompt_key"] + existing = by_key.get(key) + if existing is not None: + tgt_pid = existing["prompt_id"] + logger.info( + "adopted prompt '%s' src=%s -> tgt=%s", + key, src_prompt_id, tgt_pid, + ) + else: + remapped = remap_uuids(src_prompt, self.ctx.remap) + remapped["tool_id"] = tgt_tool_id + payload = build_post_payload(remapped, PROMPT_WRITABLE - SERVER_MANAGED) + tgt = self.ctx.target.create_prompt(tgt_tool_id, payload) + tgt_pid = tgt["prompt_id"] + logger.info( + "created prompt '%s' src=%s -> tgt=%s", + key, src_prompt_id, tgt_pid, + ) + self.ctx.remap.record("prompt", src_prompt_id, tgt_pid) diff --git a/tests/migration/test_custom_tool_phase.py b/tests/migration/test_custom_tool_phase.py new file mode 100644 index 0000000..0b30863 --- /dev/null +++ b/tests/migration/test_custom_tool_phase.py @@ -0,0 +1,342 @@ +"""Tests for ``CustomToolPhase`` — the composite tool + profile + prompt phase. + +Coverage: +- happy path: fresh tool → tool created, auto-default deleted, source + profiles + prompts created, registry republished. +- idempotency: re-run is a no-op when tool + profile names + prompt keys + already match on target. +- adopt path: existing tool on target, partial overlap of profiles — + matching ones adopted, missing ones created. +- dry-run: nothing posted. +- adapter UUID remap into profile FKs. 
+""" + +from __future__ import annotations + +from unstract.migration.context import ( + MigrationContext, + MigrationOptions, + RemapTable, +) +from unstract.migration.phases.custom_tool import CustomToolPhase +from unstract.migration.report import MigrationReport + + +TOOL_POST_SCHEMA = frozenset( + { + "tool_name", + "description", + "author", + "icon", + "preamble", + "postamble", + "prompt_grammer", + "monitor_llm", + "challenge_llm", + "summarize_llm_adapter", + "custom_data", + "single_pass_extraction_mode", + "shared_users", + "shared_to_org", + } +) + + +class FakeClient: + """In-memory stand-in for ``PlatformClient`` covering the prompt-studio surface.""" + + def __init__(self) -> None: + self.tools: dict[str, dict] = {} + self.profiles_by_tool: dict[str, list[dict]] = {} + self.prompts_by_tool: dict[str, list[dict]] = {} + self.export_calls: list[str] = [] + self._next = 1 + + # --- ID helper --- + def _mint(self, prefix: str) -> str: + s = f"tgt-{prefix}-{self._next:04d}" + self._next += 1 + return s + + # --- schema --- + def get_post_schema(self, entity_path: str) -> frozenset[str]: + if entity_path == "prompt-studio/": + return TOOL_POST_SCHEMA + raise AssertionError(f"unexpected OPTIONS path: {entity_path}") + + # --- tools --- + def list_custom_tools(self) -> list[dict]: + return list(self.tools.values()) + + def get_custom_tool(self, tool_id: str) -> dict: + return self.tools[tool_id] + + def create_custom_tool(self, payload: dict) -> dict: + tool_id = self._mint("tool") + tool = {**payload, "tool_id": tool_id, "prompts": []} + self.tools[tool_id] = tool + # Backend auto-creates a default profile on create. 
+ auto = { + "profile_id": self._mint("autoprofile"), + "profile_name": "Default", + "is_default": True, + "prompt_studio_tool": tool_id, + } + self.profiles_by_tool[tool_id] = [auto] + self.prompts_by_tool[tool_id] = [] + return tool + + def export_custom_tool(self, tool_id: str, *, force: bool = True) -> None: + self.export_calls.append(tool_id) + + # --- profiles --- + def list_profiles(self, tool_id: str) -> list[dict]: + return list(self.profiles_by_tool.get(tool_id, [])) + + def create_profile(self, tool_id: str, payload: dict) -> dict: + new = {**payload, "profile_id": self._mint("profile")} + self.profiles_by_tool.setdefault(tool_id, []).append(new) + return new + + def delete_profile(self, profile_id: str) -> None: + for tid, profiles in self.profiles_by_tool.items(): + self.profiles_by_tool[tid] = [ + p for p in profiles if p["profile_id"] != profile_id + ] + + def set_default_profile(self, tool_id: str, profile_id: str) -> None: + for p in self.profiles_by_tool.get(tool_id, []): + p["is_default"] = p["profile_id"] == profile_id + + # --- prompts --- + def list_prompts(self, *, tool_id: str) -> list[dict]: + return list(self.prompts_by_tool.get(tool_id, [])) + + def create_prompt(self, tool_id: str, payload: dict) -> dict: + new = {**payload, "prompt_id": self._mint("prompt")} + self.prompts_by_tool.setdefault(tool_id, []).append(new) + return new + + +def _ctx(source, target, *, remap=None, **opt_overrides) -> MigrationContext: + return MigrationContext( + source=source, + target=target, + options=MigrationOptions(**opt_overrides), + remap=remap or RemapTable(), + ) + + +def _src_tool(tool_id: str, name: str, prompts: list[dict] | None = None) -> dict: + return { + "tool_id": tool_id, + "tool_name": name, + "description": f"{name} desc", + "author": "src", + "icon": "", + "preamble": "", + "postamble": "", + "prompt_grammer": {}, + "monitor_llm": None, + "challenge_llm": None, + "summarize_llm_adapter": None, + "custom_data": {}, + 
"single_pass_extraction_mode": False, + "shared_users": [], + "shared_to_org": False, + "prompts": prompts or [], + } + + +def _src_profile(pid: str, name: str, *, is_default: bool = False) -> dict: + return { + "profile_id": pid, + "profile_name": name, + "is_default": is_default, + "is_summarize_llm": False, + # Mimic to_representation expansion: nested adapter dicts. + "llm": {"id": "11111111-1111-1111-1111-111111111111", "adapter_name": "L"}, + "embedding_model": {"id": "22222222-2222-2222-2222-222222222222", "adapter_name": "E"}, + "vector_store": {"id": "33333333-3333-3333-3333-333333333333", "adapter_name": "V"}, + "x2text": {"id": "44444444-4444-4444-4444-444444444444", "adapter_name": "X"}, + "chunk_size": 1024, + "chunk_overlap": 128, + "reindex": False, + "retrieval_strategy": "simple", + "similarity_top_k": 3, + "section": "default", + "prompt_studio_tool": None, + } + + +def _src_prompt(prompt_id: str, key: str, profile_id: str) -> dict: + return { + "prompt_id": prompt_id, + "prompt_key": key, + "prompt": f"What is {key}?", + "enforce_type": "string", + "prompt_type": "prompt", + "sequence_number": 1, + "tool_id": "src-tool-x", + "profile_manager": profile_id, + "output": "", + "active": True, + "required": False, + } + + +def _preload_source(client: FakeClient, tool_id: str) -> dict: + """Helper to set up a source FakeClient with one tool, one profile, one prompt.""" + profile = _src_profile("src-profile-1", "Default", is_default=True) + prompt = _src_prompt("src-prompt-1", "field_a", "src-profile-1") + tool = _src_tool(tool_id, "Invoice Extractor", prompts=[prompt]) + client.tools[tool_id] = tool + client.profiles_by_tool[tool_id] = [profile] + client.prompts_by_tool[tool_id] = [prompt] + return tool + + +def _preload_remap_with_adapters(remap: RemapTable) -> None: + remap.record("adapter", "11111111-1111-1111-1111-111111111111", "a1111111-1111-1111-1111-111111111111") + remap.record("adapter", "22222222-2222-2222-2222-222222222222", 
"a2222222-2222-2222-2222-222222222222") + remap.record("adapter", "33333333-3333-3333-3333-333333333333", "a3333333-3333-3333-3333-333333333333") + remap.record("adapter", "44444444-4444-4444-4444-444444444444", "a4444444-4444-4444-4444-444444444444") + + +def test_fresh_tool_creates_tool_profiles_prompts_and_republishes(): + src = FakeClient() + tgt = FakeClient() + _preload_source(src, "src-tool-x") + remap = RemapTable() + _preload_remap_with_adapters(remap) + ctx = _ctx(src, tgt, remap=remap) + report = MigrationReport() + + result = CustomToolPhase(ctx).run(report) + + assert result.created == 1 + assert result.failed == 0 + assert len(tgt.tools) == 1 + tgt_tool_id = next(iter(tgt.tools)) + + # Auto-default profile deleted; exactly one profile (the source's) remains. + profiles = tgt.profiles_by_tool[tgt_tool_id] + assert len(profiles) == 1 + profile = profiles[0] + assert profile["profile_name"] == "Default" + assert profile["is_default"] is True + # Adapter FKs remapped via walker. + assert profile["llm"] == "a1111111-1111-1111-1111-111111111111" + assert profile["embedding_model"] == "a2222222-2222-2222-2222-222222222222" + assert profile["vector_store"] == "a3333333-3333-3333-3333-333333333333" + assert profile["x2text"] == "a4444444-4444-4444-4444-444444444444" + + # One prompt landed, pointing at the new tool. + prompts = tgt.prompts_by_tool[tgt_tool_id] + assert len(prompts) == 1 + assert prompts[0]["prompt_key"] == "field_a" + assert prompts[0]["tool_id"] == tgt_tool_id + + # Registry republished exactly once. + assert tgt.export_calls == [tgt_tool_id] + + # Remap records populated for downstream phases. 
+ assert ctx.remap.resolve("custom_tool", "src-tool-x") == tgt_tool_id + assert ctx.remap.resolve("profile_manager", "src-profile-1") == profile["profile_id"] + assert ctx.remap.resolve("prompt", "src-prompt-1") == prompts[0]["prompt_id"] + + +def test_idempotent_rerun_does_not_create_duplicates(): + src = FakeClient() + tgt = FakeClient() + _preload_source(src, "src-tool-x") + remap = RemapTable() + _preload_remap_with_adapters(remap) + ctx = _ctx(src, tgt, remap=remap) + + CustomToolPhase(ctx).run(MigrationReport()) + tgt_tool_id = next(iter(tgt.tools)) + profile_count = len(tgt.profiles_by_tool[tgt_tool_id]) + prompt_count = len(tgt.prompts_by_tool[tgt_tool_id]) + export_count = len(tgt.export_calls) + + report2 = MigrationReport() + result2 = CustomToolPhase(ctx).run(report2) + + assert result2.adopted == 1 + assert result2.created == 0 + assert len(tgt.profiles_by_tool[tgt_tool_id]) == profile_count + assert len(tgt.prompts_by_tool[tgt_tool_id]) == prompt_count + # Republish still fires (rebuild registry idempotently). + assert len(tgt.export_calls) == export_count + 1 + + +def test_adopt_path_fills_missing_profile_only(): + src = FakeClient() + tgt = FakeClient() + # Source has TWO profiles. + extra = _src_profile("src-profile-2", "HighRecall", is_default=False) + default = _src_profile("src-profile-1", "Default", is_default=True) + prompt = _src_prompt("src-prompt-1", "field_a", "src-profile-1") + tool = _src_tool("src-tool-x", "Invoice Extractor", prompts=[prompt]) + src.tools["src-tool-x"] = tool + src.profiles_by_tool["src-tool-x"] = [default, extra] + src.prompts_by_tool["src-tool-x"] = [prompt] + + # Target already has the tool + the "Default" profile + the prompt. 
+ tgt_tool_id = "tgt-pre-tool" + tgt.tools[tgt_tool_id] = { + "tool_id": tgt_tool_id, + "tool_name": "Invoice Extractor", + "prompts": [], + } + tgt.profiles_by_tool[tgt_tool_id] = [ + { + "profile_id": "tgt-pre-profile", + "profile_name": "Default", + "is_default": True, + "prompt_studio_tool": tgt_tool_id, + } + ] + tgt.prompts_by_tool[tgt_tool_id] = [ + { + "prompt_id": "tgt-pre-prompt", + "prompt_key": "field_a", + "tool_id": tgt_tool_id, + } + ] + + remap = RemapTable() + _preload_remap_with_adapters(remap) + ctx = _ctx(src, tgt, remap=remap) + + result = CustomToolPhase(ctx).run(MigrationReport()) + + assert result.adopted == 1 + # Adopted tool path: only the missing "HighRecall" profile got created. + profiles = tgt.profiles_by_tool[tgt_tool_id] + assert len(profiles) == 2 + names = {p["profile_name"] for p in profiles} + assert names == {"Default", "HighRecall"} + + # Prompt was already there → adopted, not duplicated. + assert len(tgt.prompts_by_tool[tgt_tool_id]) == 1 + + assert ctx.remap.resolve("profile_manager", "src-profile-1") == "tgt-pre-profile" + assert ctx.remap.resolve("prompt", "src-prompt-1") == "tgt-pre-prompt" + + +def test_dry_run_creates_nothing(): + src = FakeClient() + tgt = FakeClient() + _preload_source(src, "src-tool-x") + remap = RemapTable() + _preload_remap_with_adapters(remap) + ctx = _ctx(src, tgt, remap=remap, dry_run=True) + + result = CustomToolPhase(ctx).run(MigrationReport()) + + assert result.skipped == 1 + assert tgt.tools == {} + assert tgt.profiles_by_tool == {} + assert tgt.export_calls == [] From 7fe9c33c930293ae36e493d69d06250f1fad09cc Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Sun, 24 May 2026 02:39:30 +0530 Subject: [PATCH 07/25] feat(migration): WorkflowPhase Migrates Workflow rows with walker-remapped connector UUIDs in source_settings + destination_settings JSON blobs. WorkflowEndpoints are auto-created by the backend on workflow POST; reconciled later by the dedicated WorkflowEndpoint phase. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- src/unstract/migration/client.py | 17 +++ src/unstract/migration/orchestrator.py | 2 + src/unstract/migration/phases/__init__.py | 10 +- src/unstract/migration/phases/workflow.py | 94 +++++++++++++ tests/migration/test_workflow_phase.py | 155 ++++++++++++++++++++++ 5 files changed, 277 insertions(+), 1 deletion(-) create mode 100644 src/unstract/migration/phases/workflow.py create mode 100644 tests/migration/test_workflow_phase.py diff --git a/src/unstract/migration/client.py b/src/unstract/migration/client.py index 5b4b414..b456e69 100644 --- a/src/unstract/migration/client.py +++ b/src/unstract/migration/client.py @@ -226,3 +226,20 @@ def create_prompt(self, tool_id: str, payload: dict[str, Any]) -> dict[str, Any] return self._request( "POST", f"prompt-studio/prompt-studio-prompt/{tool_id}/", json=payload ) + + # ----- workflows ----- + + def list_workflows(self, *, name: str | None = None) -> list[dict[str, Any]]: + """List workflows in this org, optionally filtered by exact name.""" + params: dict[str, Any] = {} + if name is not None: + params["workflow_name"] = name + result = self._request("GET", "workflow/", params=params) + return result if isinstance(result, list) else result.get("results", []) + + def get_workflow(self, workflow_id: str) -> dict[str, Any]: + return self._request("GET", f"workflow/{workflow_id}/") + + def create_workflow(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create a workflow. 
Backend auto-creates empty WorkflowEndpoints for it.""" + return self._request("POST", "workflow/", json=payload) diff --git a/src/unstract/migration/orchestrator.py b/src/unstract/migration/orchestrator.py index 59e506b..61773eb 100644 --- a/src/unstract/migration/orchestrator.py +++ b/src/unstract/migration/orchestrator.py @@ -21,6 +21,7 @@ ConnectorPhase, CustomToolPhase, TagPhase, + WorkflowPhase, ) from unstract.migration.phases.base import Phase from unstract.migration.report import MigrationReport @@ -36,6 +37,7 @@ ("connector", ConnectorPhase), ("tag", TagPhase), ("custom_tool", CustomToolPhase), + ("workflow", WorkflowPhase), ] diff --git a/src/unstract/migration/phases/__init__.py b/src/unstract/migration/phases/__init__.py index 3462a05..b8b4ab9 100644 --- a/src/unstract/migration/phases/__init__.py +++ b/src/unstract/migration/phases/__init__.py @@ -12,5 +12,13 @@ from unstract.migration.phases.connector import ConnectorPhase from unstract.migration.phases.custom_tool import CustomToolPhase from unstract.migration.phases.tag import TagPhase +from unstract.migration.phases.workflow import WorkflowPhase -__all__ = ["AdapterPhase", "ConnectorPhase", "CustomToolPhase", "Phase", "TagPhase"] +__all__ = [ + "AdapterPhase", + "ConnectorPhase", + "CustomToolPhase", + "Phase", + "TagPhase", + "WorkflowPhase", +] diff --git a/src/unstract/migration/phases/workflow.py b/src/unstract/migration/phases/workflow.py new file mode 100644 index 0000000..49917ee --- /dev/null +++ b/src/unstract/migration/phases/workflow.py @@ -0,0 +1,94 @@ +"""Migrate workflows from source org to target org. + +Workflow rows themselves are simple — no required FKs to migration +entities, unique per ``(workflow_name, organization)``. The two +non-trivial bits: + +1. ``source_settings`` and ``destination_settings`` are JSON blobs that + embed connector UUIDs. The walker remaps them using the running + ``RemapTable`` (connectors already landed in the previous phase). + +2. 
class WorkflowPhase(Phase):
    """Create or adopt every source workflow on the target, matched by name.

    Each migrated workflow's source id → target id pair is recorded in the
    ``workflow`` remap for later phases.
    """

    name = "workflow"

    def run(self, report: MigrationReport) -> PhaseResult:
        """Migrate all source workflows and return this phase's result.

        Failing to fetch the target POST schema or to list source
        workflows is recorded as a single failure and ends the phase
        early. Per-workflow failures are recorded and the loop continues.

        Raises:
            NameConflictError: If a workflow already exists on target and
                ``on_name_conflict`` is ``"abort"``.
        """
        result = report.get_phase(self.name)
        try:
            # Writable field set for workflow POST, used to trim payloads.
            self._writable = self.ctx.target.get_post_schema(WORKFLOW_PATH)
        except Exception as e:
            logger.exception("Failed to fetch target POST schema for workflow: %s", e)
            result.failed += 1
            result.errors.append(f"OPTIONS workflow: {e}")
            return result

        try:
            src_workflows = self.ctx.source.list_workflows()
        except Exception as e:
            logger.exception("Failed to list source workflows: %s", e)
            result.failed += 1
            result.errors.append(f"list source workflows: {e}")
            return result

        logger.info("Found %d workflow(s) in source org", len(src_workflows))
        for src in src_workflows:
            self._migrate_one(src, result)
        return result

    def _migrate_one(self, src: dict[str, Any], result: PhaseResult) -> None:
        """Adopt or create a single workflow on target and record its remap.

        Args:
            src: Workflow dict from the source org; must carry
                ``workflow_name`` and ``id``.
            result: Phase counters and error list, updated in place.

        Raises:
            NameConflictError: If the workflow exists on target and the
                conflict policy is ``"abort"``.
        """
        name = src["workflow_name"]
        src_id = src["id"]

        try:
            candidates = self.ctx.target.list_workflows(name=name)
        except Exception as e:
            logger.exception("Failed to GET workflow %s on target: %s", name, e)
            result.failed += 1
            result.errors.append(f"GET {name}: {e}")
            return

        # Do not trust the server-side filter to be an exact match: if the
        # backend ignores the parameter or matches loosely, taking the
        # first row would adopt an unrelated workflow and poison the remap.
        existing = [w for w in candidates if w.get("workflow_name") == name]

        if existing:
            tgt = existing[0]
            if self.ctx.options.on_name_conflict == "abort":
                raise NameConflictError(
                    f"workflow '{name}' already exists in target as {tgt['id']}"
                )
            result.adopted += 1
            logger.info("adopted workflow '%s' src=%s -> tgt=%s", name, src_id, tgt["id"])
        elif self.ctx.options.dry_run:
            result.skipped += 1
            logger.info("[dry-run] would create workflow '%s' src=%s", name, src_id)
            return
        else:
            # Rewrite embedded source UUIDs to their target counterparts
            # before trimming the payload to writable fields.
            remapped = remap_uuids(src, self.ctx.remap)
            payload = build_post_payload(remapped, self._writable)
            try:
                tgt = self.ctx.target.create_workflow(payload)
            except Exception as e:
                logger.exception("Failed to create workflow %s: %s", name, e)
                result.failed += 1
                result.errors.append(f"create {name}: {e}")
                return
            result.created += 1
            logger.info("created workflow '%s' src=%s -> tgt=%s", name, src_id, tgt["id"])

        self.ctx.remap.record("workflow", src_id, tgt["id"])
+""" + +from __future__ import annotations + +import pytest + +from unstract.migration.context import ( + MigrationContext, + MigrationOptions, + RemapTable, +) +from unstract.migration.exceptions import NameConflictError +from unstract.migration.phases.workflow import WorkflowPhase +from unstract.migration.report import MigrationReport + + +WORKFLOW_POST_SCHEMA = frozenset( + { + "workflow_name", + "description", + "is_active", + "deployment_type", + "source_settings", + "destination_settings", + "max_file_execution_count", + "shared_users", + "shared_to_org", + } +) + + +class FakeClient: + def __init__(self, workflows: list[dict] | None = None): + self.workflows: list[dict] = list(workflows or []) + self.posts: list[dict] = [] + self._next_id = 1 + + def get_post_schema(self, entity_path: str) -> frozenset[str]: + return WORKFLOW_POST_SCHEMA + + def list_workflows(self, *, name: str | None = None): + result = self.workflows + if name is not None: + result = [w for w in result if w["workflow_name"] == name] + return list(result) + + def create_workflow(self, payload: dict) -> dict: + new = dict(payload) + new["id"] = f"tgt-{self._next_id:08d}-0000-0000-0000-000000000000" + self._next_id += 1 + self.workflows.append(new) + self.posts.append(new) + return new + + +def _src(id_, name, *, source_settings=None, destination_settings=None): + return { + "id": id_, + "workflow_name": name, + "description": f"{name} desc", + "is_active": True, + "deployment_type": "DEFAULT", + "source_settings": source_settings or {}, + "destination_settings": destination_settings or {}, + "max_file_execution_count": None, + "shared_users": [], + "shared_to_org": False, + } + + +def _ctx(source, target, *, remap=None, **opt_overrides): + return MigrationContext( + source=source, + target=target, + options=MigrationOptions(**opt_overrides), + remap=remap or RemapTable(), + ) + + +def test_happy_path_creates_workflow_and_remaps_connector_uuids(): + src_conn = 
"11111111-1111-1111-1111-111111111111" + tgt_conn = "a1111111-1111-1111-1111-111111111111" + src = FakeClient( + [ + _src( + "wf-src-1", + "Invoice ETL", + source_settings={"connector_id": src_conn, "extras": {"a": 1}}, + destination_settings={"connector_id": src_conn}, + ) + ] + ) + tgt = FakeClient() + remap = RemapTable() + remap.record("connector", src_conn, tgt_conn) + ctx = _ctx(src, tgt, remap=remap) + + result = WorkflowPhase(ctx).run(MigrationReport()) + + assert result.created == 1 + assert result.failed == 0 + assert len(tgt.posts) == 1 + posted = tgt.posts[0] + # Walker rewrote both occurrences of the source connector UUID. + assert posted["source_settings"]["connector_id"] == tgt_conn + assert posted["destination_settings"]["connector_id"] == tgt_conn + # Unrelated nested data passes through untouched. + assert posted["source_settings"]["extras"] == {"a": 1} + + assert ctx.remap.resolve("workflow", "wf-src-1") == posted["id"] + + +def test_idempotent_rerun_adopts_existing_workflow(): + src = FakeClient([_src("wf-src-1", "Invoice ETL")]) + tgt = FakeClient( + [{"id": "wf-tgt-pre", "workflow_name": "Invoice ETL"}] + ) + ctx = _ctx(src, tgt, on_name_conflict="adopt") + + result = WorkflowPhase(ctx).run(MigrationReport()) + + assert result.adopted == 1 + assert result.created == 0 + assert tgt.posts == [] + assert ctx.remap.resolve("workflow", "wf-src-1") == "wf-tgt-pre" + + +def test_dry_run_creates_nothing(): + src = FakeClient([_src("wf-src-1", "Invoice ETL")]) + tgt = FakeClient() + ctx = _ctx(src, tgt, dry_run=True) + + result = WorkflowPhase(ctx).run(MigrationReport()) + + assert result.skipped == 1 + assert tgt.posts == [] + + +def test_abort_on_name_conflict_raises(): + src = FakeClient([_src("wf-src-1", "Invoice ETL")]) + tgt = FakeClient( + [{"id": "wf-tgt-pre", "workflow_name": "Invoice ETL"}] + ) + ctx = _ctx(src, tgt, on_name_conflict="abort") + + with pytest.raises(NameConflictError): + WorkflowPhase(ctx).run(MigrationReport()) From 
2a0604cae6ec37b3f2208dce44f4d5c59ec13a7f Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Sun, 24 May 2026 03:07:49 +0530 Subject: [PATCH 08/25] feat(migration): ToolInstance + WorkflowEndpoint phases Closes the workflow execution loop on target. CustomToolPhase now records a prompt_studio_registry remap (src_registry_id -> tgt_registry_id) by looking up both sides via the newly-filterable registry list endpoint. ToolInstancePhase walks the workflow remap, creates one bare instance per target workflow (POST overrides metadata server-side), then PATCHes metadata so source's adapter selections survive. Source metadata stores adapters as names, which match across orgs since AdapterPhase preserves them; the backend resolves names to local UUIDs on PATCH. WorkflowEndpointPhase PATCHes the SOURCE/DESTINATION endpoints that the backend auto-created on workflow POST, pairing by endpoint_type and remapping connector FK + walker-rewriting embedded UUIDs in configuration. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/unstract/migration/client.py | 68 ++++++ src/unstract/migration/orchestrator.py | 4 + src/unstract/migration/phases/__init__.py | 4 + src/unstract/migration/phases/custom_tool.py | 23 ++ .../migration/phases/tool_instance.py | 133 +++++++++++ .../migration/phases/workflow_endpoint.py | 149 ++++++++++++ tests/migration/test_custom_tool_phase.py | 20 ++ tests/migration/test_tool_instance_phase.py | 177 ++++++++++++++ .../migration/test_workflow_endpoint_phase.py | 218 ++++++++++++++++++ 9 files changed, 796 insertions(+) create mode 100644 src/unstract/migration/phases/tool_instance.py create mode 100644 src/unstract/migration/phases/workflow_endpoint.py create mode 100644 tests/migration/test_tool_instance_phase.py create mode 100644 tests/migration/test_workflow_endpoint_phase.py diff --git a/src/unstract/migration/client.py b/src/unstract/migration/client.py index b456e69..ef622b2 100644 --- a/src/unstract/migration/client.py +++ 
b/src/unstract/migration/client.py @@ -243,3 +243,71 @@ def get_workflow(self, workflow_id: str) -> dict[str, Any]: def create_workflow(self, payload: dict[str, Any]) -> dict[str, Any]: """Create a workflow. Backend auto-creates empty WorkflowEndpoints for it.""" return self._request("POST", "workflow/", json=payload) + + # ----- prompt studio registry ----- + + def list_registries( + self, *, custom_tool: str | None = None + ) -> list[dict[str, Any]]: + """List PromptStudioRegistry rows. The list endpoint returns nothing + unless a filter is supplied; pass ``custom_tool`` to look up the + registry id for a given tool. + """ + params: dict[str, Any] = {} + if custom_tool is not None: + params["custom_tool"] = custom_tool + result = self._request("GET", "prompt-studio/registry/", params=params) + return result if isinstance(result, list) else result.get("results", []) + + # ----- tool instances ----- + + def list_tool_instances( + self, *, workflow_id: str | None = None + ) -> list[dict[str, Any]]: + """List ToolInstance rows, optionally scoped to a workflow.""" + params: dict[str, Any] = {} + if workflow_id is not None: + params["workflow"] = workflow_id + result = self._request("GET", "tool_instance/", params=params) + return result if isinstance(result, list) else result.get("results", []) + + def create_tool_instance(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create a tool instance (max 1 per workflow). The backend overwrites + the ``metadata`` field with tool defaults — caller must PATCH after + create to transfer source metadata. + """ + return self._request("POST", "tool_instance/", json=payload) + + def update_tool_instance_metadata( + self, instance_id: str, metadata: dict[str, Any] + ) -> dict[str, Any]: + """PATCH a tool instance's metadata. Backend resolves adapter names + in the payload to local UUIDs via ``update_instance_metadata``. 
+ """ + return self._request( + "PATCH", f"tool_instance/{instance_id}/", json={"metadata": metadata} + ) + + # ----- workflow endpoints ----- + + def list_workflow_endpoints( + self, *, workflow_id: str | None = None + ) -> list[dict[str, Any]]: + """List workflow endpoints, optionally filtered by workflow id. + + The backend auto-creates one SOURCE and one DESTINATION endpoint + per workflow, so a workflow filter typically returns exactly two + rows. + """ + params: dict[str, Any] = {} + if workflow_id is not None: + params["workflow"] = workflow_id + result = self._request("GET", "workflow/endpoint/", params=params) + return result if isinstance(result, list) else result.get("results", []) + + def update_workflow_endpoint( + self, endpoint_id: str, payload: dict[str, Any] + ) -> dict[str, Any]: + return self._request( + "PATCH", f"workflow/endpoint/{endpoint_id}/", json=payload + ) diff --git a/src/unstract/migration/orchestrator.py b/src/unstract/migration/orchestrator.py index 61773eb..8a58ff2 100644 --- a/src/unstract/migration/orchestrator.py +++ b/src/unstract/migration/orchestrator.py @@ -21,6 +21,8 @@ ConnectorPhase, CustomToolPhase, TagPhase, + ToolInstancePhase, + WorkflowEndpointPhase, WorkflowPhase, ) from unstract.migration.phases.base import Phase @@ -38,6 +40,8 @@ ("tag", TagPhase), ("custom_tool", CustomToolPhase), ("workflow", WorkflowPhase), + ("tool_instance", ToolInstancePhase), + ("workflow_endpoint", WorkflowEndpointPhase), ] diff --git a/src/unstract/migration/phases/__init__.py b/src/unstract/migration/phases/__init__.py index b8b4ab9..125b19b 100644 --- a/src/unstract/migration/phases/__init__.py +++ b/src/unstract/migration/phases/__init__.py @@ -12,7 +12,9 @@ from unstract.migration.phases.connector import ConnectorPhase from unstract.migration.phases.custom_tool import CustomToolPhase from unstract.migration.phases.tag import TagPhase +from unstract.migration.phases.tool_instance import ToolInstancePhase from 
unstract.migration.phases.workflow import WorkflowPhase +from unstract.migration.phases.workflow_endpoint import WorkflowEndpointPhase __all__ = [ "AdapterPhase", @@ -20,5 +22,7 @@ "CustomToolPhase", "Phase", "TagPhase", + "ToolInstancePhase", + "WorkflowEndpointPhase", "WorkflowPhase", ] diff --git a/src/unstract/migration/phases/custom_tool.py b/src/unstract/migration/phases/custom_tool.py index d32a855..d194ead 100644 --- a/src/unstract/migration/phases/custom_tool.py +++ b/src/unstract/migration/phases/custom_tool.py @@ -191,6 +191,29 @@ def _migrate_one(self, summary: dict[str, Any], result: PhaseResult) -> None: logger.exception("Registry republish failed for tool %s: %s", tool_name, e) result.failed += 1 result.errors.append(f"export {tool_name}: {e}") + return + + # Record the registry-id remap so ToolInstancePhase can rewrite + # ToolInstance.tool_id (which holds a registry UUID as CharField). + # Source-side registry exists only if the operator already published + # the tool; un-published tools have no ToolInstance to migrate. + try: + src_regs = self.ctx.source.list_registries(custom_tool=src_tool_id) + tgt_regs = self.ctx.target.list_registries(custom_tool=tgt_tool_id) + except Exception as e: + logger.warning( + "registry remap lookup failed for tool '%s' (downstream " + "ToolInstance migration may skip): %s", + tool_name, e, + ) + return + + if src_regs and tgt_regs: + src_reg_id = src_regs[0]["prompt_registry_id"] + tgt_reg_id = tgt_regs[0]["prompt_registry_id"] + self.ctx.remap.record( + "prompt_studio_registry", src_reg_id, tgt_reg_id + ) def _get_or_create_tool( self, src_tool: dict[str, Any], result: PhaseResult diff --git a/src/unstract/migration/phases/tool_instance.py b/src/unstract/migration/phases/tool_instance.py new file mode 100644 index 0000000..e4d3cc0 --- /dev/null +++ b/src/unstract/migration/phases/tool_instance.py @@ -0,0 +1,133 @@ +"""Migrate ToolInstance rows from source org to target org. 
class ToolInstancePhase(Phase):
    """Copy each migrated workflow's ToolInstance from source to target.

    Work is driven by the ``workflow`` remap table rather than a top-level
    entity listing: every ``src -> tgt`` workflow pair already recorded in
    the remap is visited once.

    For each source instance:

    - ``tool_id`` holds a prompt-registry UUID stored as a plain string; it
      is translated through the ``prompt_studio_registry`` remap table.
      Instances with no registry remap are skipped.
    - A missing target instance is created bare (``workflow_id`` +
      ``tool_id``) and the source ``metadata`` is sent in a follow-up PATCH.
      Per the design notes for this phase, the backend rebuilds metadata
      from tool defaults on create and resolves adapter names to target
      adapter UUIDs on PATCH.
    - An instance already present on the target workflow is adopted rather
      than re-created; its metadata is still re-PATCHed so tool config
      stays aligned with source on every run.

    Dry-run never writes to the target: neither the create POST nor the
    metadata PATCH is issued.
    """

    name = "tool_instance"

    def run(self, report: MigrationReport) -> PhaseResult:
        """Migrate tool instances for every workflow pair in the remap table.

        Returns the phase result obtained from ``report.get_phase``; it is
        returned untouched when no workflows have been remapped.
        """
        result = report.get_phase(self.name)
        workflow_remap = self.ctx.remap.snapshot().get("workflow", {})
        if not workflow_remap:
            logger.info("No workflows in remap; nothing to do for tool_instance phase")
            return result

        for src_wf_id, tgt_wf_id in workflow_remap.items():
            self._migrate_workflow_tools(src_wf_id, tgt_wf_id, result)
        return result

    def _migrate_workflow_tools(
        self, src_wf_id: str, tgt_wf_id: str, result: PhaseResult
    ) -> None:
        """Migrate the tool instance of one source workflow onto its target.

        Failures are recorded on ``result`` (``failed`` counter + ``errors``)
        and logged; nothing is raised to the caller.
        """
        try:
            src_instances = self.ctx.source.list_tool_instances(workflow_id=src_wf_id)
        except Exception as e:
            logger.exception(
                "Failed to list source tool_instances for wf %s: %s", src_wf_id, e
            )
            result.failed += 1
            result.errors.append(f"list src tool_instances {src_wf_id}: {e}")
            return

        if not src_instances:
            return
        if len(src_instances) > 1:
            # Backend enforces ≤1; warn loudly if invariant breaks on source.
            logger.warning(
                "source workflow %s has %d tool_instances (expected ≤1) — migrating first only",
                src_wf_id, len(src_instances),
            )

        src_ti = src_instances[0]
        src_ti_id = src_ti["id"]
        src_tool_id = src_ti["tool_id"]

        tgt_tool_id = self.ctx.remap.resolve("prompt_studio_registry", src_tool_id)
        if not tgt_tool_id:
            logger.warning(
                "skipping tool_instance %s — no registry remap for tool_id %s "
                "(custom tool likely unpublished on source)",
                src_ti_id, src_tool_id,
            )
            result.skipped += 1
            return

        try:
            existing = self.ctx.target.list_tool_instances(workflow_id=tgt_wf_id)
        except Exception as e:
            logger.exception(
                "Failed to list target tool_instances for wf %s: %s", tgt_wf_id, e
            )
            result.failed += 1
            result.errors.append(f"list tgt tool_instances {tgt_wf_id}: {e}")
            return

        dry_run = self.ctx.options.dry_run

        if existing:
            tgt_ti = existing[0]
            result.adopted += 1
            logger.info(
                "adopted tool_instance src=%s -> tgt=%s (workflow %s)",
                src_ti_id, tgt_ti["id"], tgt_wf_id,
            )
        elif dry_run:
            result.skipped += 1
            logger.info(
                "[dry-run] would create tool_instance for tgt workflow %s "
                "(src tool_instance %s)",
                tgt_wf_id, src_ti_id,
            )
            return
        else:
            try:
                tgt_ti = self.ctx.target.create_tool_instance(
                    {"workflow_id": tgt_wf_id, "tool_id": tgt_tool_id}
                )
            except Exception as e:
                logger.exception(
                    "Failed to create tool_instance for wf %s: %s", tgt_wf_id, e
                )
                result.failed += 1
                result.errors.append(f"create tool_instance {tgt_wf_id}: {e}")
                return
            result.created += 1
            logger.info(
                "created tool_instance src=%s -> tgt=%s (workflow %s)",
                src_ti_id, tgt_ti["id"], tgt_wf_id,
            )

        if dry_run:
            # Only the adopt branch reaches this point in dry-run. The PATCH
            # below is a write against the target, so it must not be issued;
            # the remap itself is in-memory only and is still recorded so
            # later phases see the adopted pairing.
            logger.info(
                "[dry-run] would PATCH metadata on tool_instance %s "
                "(src tool_instance %s)",
                tgt_ti["id"], src_ti_id,
            )
            self.ctx.remap.record("tool_instance", src_ti_id, tgt_ti["id"])
            return

        # PATCH the metadata regardless of created/adopted — keeps tool config
        # aligned with source on every run.
        src_metadata = src_ti.get("metadata") or {}
        try:
            self.ctx.target.update_tool_instance_metadata(tgt_ti["id"], src_metadata)
        except Exception as e:
            logger.exception(
                "Failed to PATCH tool_instance %s metadata: %s", tgt_ti["id"], e
            )
            result.failed += 1
            result.errors.append(f"patch metadata {tgt_ti['id']}: {e}")
            return

        self.ctx.remap.record("tool_instance", src_ti_id, tgt_ti["id"])
def _extract_connector_id(endpoint: dict[str, Any]) -> str | None:
    """Return the connector UUID referenced by an endpoint, if any.

    ``connector_instance`` may be a nested dict (its ``id`` is returned), a
    bare string (returned as-is), or anything else (``None``).
    """
    raw = endpoint.get("connector_instance")
    if isinstance(raw, str):
        return raw
    return raw.get("id") if isinstance(raw, dict) else None


class WorkflowEndpointPhase(Phase):
    """PATCH target workflow endpoints so they mirror the source's.

    Nothing is POSTed by this phase: source and target endpoints of each
    remapped workflow are paired by ``endpoint_type`` and the target one is
    updated with the source's ``connection_type``, a UUID-remapped copy of
    ``configuration`` and the remapped connector id.
    """

    name = "workflow_endpoint"

    def run(self, report: MigrationReport) -> PhaseResult:
        """Process every ``src -> tgt`` workflow pair found in the remap table."""
        result = report.get_phase(self.name)
        pairs = self.ctx.remap.snapshot().get("workflow", {})
        if pairs:
            for src_wf_id, tgt_wf_id in pairs.items():
                self._migrate_workflow_endpoints(src_wf_id, tgt_wf_id, result)
        else:
            logger.info("No workflows in remap; nothing to do for workflow_endpoint phase")
        return result

    def _migrate_workflow_endpoints(
        self, src_wf_id: str, tgt_wf_id: str, result: PhaseResult
    ) -> None:
        """Pair the endpoints of one workflow by type and patch each pair."""
        try:
            src_endpoints = self.ctx.source.list_workflow_endpoints(
                workflow_id=src_wf_id
            )
        except Exception as e:
            logger.exception(
                "Failed to list source endpoints for wf %s: %s", src_wf_id, e
            )
            result.failed += 1
            result.errors.append(f"list src endpoints {src_wf_id}: {e}")
            return

        try:
            tgt_endpoints = self.ctx.target.list_workflow_endpoints(
                workflow_id=tgt_wf_id
            )
        except Exception as e:
            logger.exception(
                "Failed to list target endpoints for wf %s: %s", tgt_wf_id, e
            )
            result.failed += 1
            result.errors.append(f"list tgt endpoints {tgt_wf_id}: {e}")
            return

        # Index the target side once; a later duplicate type overrides an
        # earlier one.
        tgt_by_type: dict[Any, dict[str, Any]] = {}
        for candidate in tgt_endpoints:
            tgt_by_type[candidate["endpoint_type"]] = candidate

        for src_ep in src_endpoints:
            etype = src_ep["endpoint_type"]
            counterpart = tgt_by_type.get(etype)
            if counterpart is not None:
                self._patch_endpoint(src_ep, counterpart, result)
                continue

            # The target is expected to carry an endpoint of every type the
            # source has; a gap is reported as a failure, not skipped quietly.
            logger.warning(
                "target workflow %s missing %s endpoint — skipping",
                tgt_wf_id, etype,
            )
            result.failed += 1
            result.errors.append(
                f"missing tgt {etype} endpoint for wf {tgt_wf_id}"
            )

    def _patch_endpoint(
        self, src_ep: dict[str, Any], tgt_ep: dict[str, Any], result: PhaseResult
    ) -> None:
        """PATCH one target endpoint from its source counterpart.

        In dry-run the pair is only logged and counted as skipped. A source
        connector with no remap entry is logged and sent as-is (unset).
        """
        src_ep_id = src_ep["id"]
        tgt_ep_id = tgt_ep["id"]
        etype = src_ep["endpoint_type"]

        if self.ctx.options.dry_run:
            result.skipped += 1
            logger.info(
                "[dry-run] would PATCH %s endpoint src=%s -> tgt=%s",
                etype, src_ep_id, tgt_ep_id,
            )
            return

        tgt_conn_id: str | None = None
        src_conn_id = _extract_connector_id(src_ep)
        if src_conn_id:
            tgt_conn_id = self.ctx.remap.resolve("connector", src_conn_id)
            if not tgt_conn_id:
                logger.warning(
                    "no connector remap for %s on %s endpoint %s — leaving unset",
                    src_conn_id, etype, src_ep_id,
                )

        connection_type = src_ep.get("connection_type") or ""
        configuration = remap_uuids(
            src_ep.get("configuration") or {}, self.ctx.remap
        )
        payload: dict[str, Any] = {
            "connection_type": connection_type,
            "configuration": configuration,
            "connector_instance_id": tgt_conn_id,
        }

        try:
            self.ctx.target.update_workflow_endpoint(tgt_ep_id, payload)
        except Exception as e:
            logger.exception(
                "Failed to PATCH %s endpoint tgt=%s: %s", etype, tgt_ep_id, e
            )
            result.failed += 1
            result.errors.append(f"patch {etype} {tgt_ep_id}: {e}")
        else:
            result.created += 1
            logger.info(
                "patched %s endpoint src=%s -> tgt=%s (connector %s)",
                etype, src_ep_id, tgt_ep_id, tgt_conn_id,
            )
            self.ctx.remap.record("workflow_endpoint", src_ep_id, tgt_ep_id)
__init__(self) -> None: self.tools: dict[str, dict] = {} self.profiles_by_tool: dict[str, list[dict]] = {} self.prompts_by_tool: dict[str, list[dict]] = {} + self.registries_by_tool: dict[str, dict] = {} self.export_calls: list[str] = [] self._next = 1 @@ -88,6 +89,17 @@ def create_custom_tool(self, payload: dict) -> dict: def export_custom_tool(self, tool_id: str, *, force: bool = True) -> None: self.export_calls.append(tool_id) + # Mimic the backend: export creates/updates a registry row for the tool. + self.registries_by_tool.setdefault( + tool_id, + {"prompt_registry_id": self._mint("registry"), "custom_tool": tool_id}, + ) + + def list_registries(self, *, custom_tool: str | None = None) -> list[dict]: + if custom_tool is None: + return list(self.registries_by_tool.values()) + reg = self.registries_by_tool.get(custom_tool) + return [reg] if reg else [] # --- profiles --- def list_profiles(self, tool_id: str) -> list[dict]: @@ -193,6 +205,10 @@ def _preload_source(client: FakeClient, tool_id: str) -> dict: client.tools[tool_id] = tool client.profiles_by_tool[tool_id] = [profile] client.prompts_by_tool[tool_id] = [prompt] + client.registries_by_tool[tool_id] = { + "prompt_registry_id": "55555555-5555-5555-5555-555555555555", + "custom_tool": tool_id, + } return tool @@ -244,6 +260,10 @@ def test_fresh_tool_creates_tool_profiles_prompts_and_republishes(): assert ctx.remap.resolve("custom_tool", "src-tool-x") == tgt_tool_id assert ctx.remap.resolve("profile_manager", "src-profile-1") == profile["profile_id"] assert ctx.remap.resolve("prompt", "src-prompt-1") == prompts[0]["prompt_id"] + # Registry remap recorded for ToolInstancePhase consumption. 
class FakeClient:
    """In-memory stand-in for the platform client's tool-instance calls."""

    def __init__(self) -> None:
        self._next = 1
        self.patch_calls: list[tuple[str, dict]] = []
        self.create_calls: list[dict] = []
        # workflow_id -> tool_instances attached to that workflow.
        self.instances: dict[str, list[dict]] = {}

    def _mint(self) -> str:
        minted = f"tgt-ti-{self._next:04d}"
        self._next += 1
        return minted

    def list_tool_instances(self, *, workflow_id: str | None = None) -> list[dict]:
        if workflow_id is not None:
            return list(self.instances.get(workflow_id, []))
        flat: list[dict] = []
        for group in self.instances.values():
            flat.extend(group)
        return flat

    def create_tool_instance(self, payload: dict) -> dict:
        wf_id = payload["workflow_id"]
        created = dict(payload)
        created["id"] = self._mint()
        # Stored metadata starts as defaults, not whatever was POSTed.
        created["metadata"] = {"defaults": True}
        self.instances.setdefault(wf_id, []).append(created)
        self.create_calls.append(created)
        return created

    def update_tool_instance_metadata(
        self, instance_id: str, metadata: dict
    ) -> dict:
        self.patch_calls.append((instance_id, metadata))
        match = next(
            (
                ti
                for group in self.instances.values()
                for ti in group
                if ti["id"] == instance_id
            ),
            None,
        )
        if match is None:
            raise KeyError(instance_id)
        match["metadata"] = metadata
        return match


def _ctx(source, target, *, remap=None, **opt_overrides):
    """Build a MigrationContext around two fake clients."""
    options = MigrationOptions(**opt_overrides)
    table = remap or RemapTable()
    return MigrationContext(
        source=source,
        target=target,
        options=options,
        remap=table,
    )


def _src_ti(ti_id: str, wf_id: str, tool_id: str, metadata: dict) -> dict:
    """Shape of a tool_instance row as listed from the source."""
    row = {"id": ti_id, "workflow": wf_id, "tool_id": tool_id}
    row["metadata"] = metadata
    row["step"] = 1
    return row


SRC_WF = "10000000-0000-0000-0000-000000000001"
TGT_WF = "20000000-0000-0000-0000-000000000001"
SRC_REG = "30000000-0000-0000-0000-000000000001"
TGT_REG = "40000000-0000-0000-0000-000000000001"


def _seed_remap() -> RemapTable:
    """Remap table as the earlier workflow/custom-tool phases would leave it."""
    table = RemapTable()
    for kind, src_id, tgt_id in (
        ("workflow", SRC_WF, TGT_WF),
        ("prompt_studio_registry", SRC_REG, TGT_REG),
    ):
        table.record(kind, src_id, tgt_id)
    return table


def test_happy_path_creates_instance_then_patches_metadata():
    source_metadata = {
        "llm": "My OpenAI", "embedding": "MyEmb", "tenant_id": "src-org",
    }
    src = FakeClient()
    src.instances[SRC_WF] = [_src_ti("src-ti-1", SRC_WF, SRC_REG, source_metadata)]
    tgt = FakeClient()
    ctx = _ctx(src, tgt, remap=_seed_remap())

    result = ToolInstancePhase(ctx).run(MigrationReport())

    assert result.created == 1
    assert result.failed == 0
    assert len(tgt.create_calls) == 1
    posted = tgt.create_calls[0]
    assert posted["workflow_id"] == TGT_WF
    assert posted["tool_id"] == TGT_REG
    # The PATCH carries the source metadata verbatim.
    assert len(tgt.patch_calls) == 1
    patched_id, patched_metadata = tgt.patch_calls[0]
    assert patched_id == posted["id"]
    assert patched_metadata == {
        "llm": "My OpenAI", "embedding": "MyEmb", "tenant_id": "src-org",
    }
    assert ctx.remap.resolve("tool_instance", "src-ti-1") == posted["id"]


def test_skip_when_registry_remap_missing():
    src = FakeClient()
    src.instances[SRC_WF] = [_src_ti("src-ti-1", SRC_WF, "unknown-reg", {})]
    tgt = FakeClient()
    # Only the workflow is remapped; no prompt_studio_registry entry exists,
    # so the phase must skip the instance.
    workflow_only = RemapTable()
    workflow_only.record("workflow", SRC_WF, TGT_WF)
    ctx = _ctx(src, tgt, remap=workflow_only)

    result = ToolInstancePhase(ctx).run(MigrationReport())

    assert result.skipped == 1
    assert result.created == 0
    assert tgt.create_calls == []


def test_adopt_existing_target_instance_and_repatch_metadata():
    src_meta = {"llm": "My OpenAI"}
    src = FakeClient()
    src.instances[SRC_WF] = [_src_ti("src-ti-1", SRC_WF, SRC_REG, src_meta)]
    tgt = FakeClient()
    preexisting = {
        "id": "tgt-pre-ti", "workflow": TGT_WF, "tool_id": TGT_REG, "metadata": {},
    }
    tgt.instances[TGT_WF] = [preexisting]
    ctx = _ctx(src, tgt, remap=_seed_remap())

    result = ToolInstancePhase(ctx).run(MigrationReport())

    assert result.adopted == 1
    assert result.created == 0
    assert tgt.create_calls == []
    # PATCH still fires for the adopted instance to align metadata.
    assert tgt.patch_calls == [("tgt-pre-ti", src_meta)]
    assert ctx.remap.resolve("tool_instance", "src-ti-1") == "tgt-pre-ti"


def test_no_op_when_no_workflows_in_remap():
    tgt = FakeClient()
    ctx = _ctx(FakeClient(), tgt, remap=RemapTable())

    result = ToolInstancePhase(ctx).run(MigrationReport())

    assert result.created == 0
    assert result.skipped == 0
    assert tgt.create_calls == []


def test_dry_run_does_not_create_or_patch():
    src = FakeClient()
    src.instances[SRC_WF] = [_src_ti("src-ti-1", SRC_WF, SRC_REG, {"x": 1})]
    tgt = FakeClient()
    ctx = _ctx(src, tgt, remap=_seed_remap(), dry_run=True)

    result = ToolInstancePhase(ctx).run(MigrationReport())

    assert result.skipped == 1
    assert tgt.create_calls == []
    assert tgt.patch_calls == []
class FakeClient:
    """In-memory stand-in for the platform client's workflow-endpoint calls."""

    def __init__(self) -> None:
        # workflow_id -> endpoints belonging to that workflow.
        self.endpoints: dict[str, list[dict]] = {}
        self.patch_calls: list[tuple[str, dict]] = []

    def list_workflow_endpoints(
        self, *, workflow_id: str | None = None
    ) -> list[dict]:
        if workflow_id is None:
            return [ep for eps in self.endpoints.values() for ep in eps]
        return list(self.endpoints.get(workflow_id, []))

    def update_workflow_endpoint(
        self, endpoint_id: str, payload: dict
    ) -> dict:
        self.patch_calls.append((endpoint_id, payload))
        for eps in self.endpoints.values():
            for ep in eps:
                if ep["id"] == endpoint_id:
                    ep.update(payload)
                    return ep
        raise KeyError(endpoint_id)


def _ctx(source, target, *, remap=None, **opt_overrides):
    """Build a MigrationContext around two fake clients."""
    return MigrationContext(
        source=source,
        target=target,
        options=MigrationOptions(**opt_overrides),
        remap=remap or RemapTable(),
    )


SRC_WF = "10000000-0000-0000-0000-000000000001"
TGT_WF = "20000000-0000-0000-0000-000000000001"
SRC_CONN = "30000000-0000-0000-0000-000000000001"
TGT_CONN = "40000000-0000-0000-0000-000000000001"
# Well-formed UUID that is deliberately absent from the seeded remap table.
UNMAPPED_CONN = "99999999-9999-9999-9999-999999999999"


def _src_endpoint(ep_id, etype, connector_id, configuration):
    """Source endpoint row; ``connector_instance`` is a nested dict."""
    return {
        "id": ep_id,
        "workflow": SRC_WF,
        "endpoint_type": etype,
        "connection_type": "FILESYSTEM",
        "configuration": configuration,
        "connector_instance": {"id": connector_id, "connector_name": "src-conn"},
    }


def _tgt_endpoint(ep_id, etype):
    """Blank target endpoint row waiting to be patched."""
    return {
        "id": ep_id,
        "workflow": TGT_WF,
        "endpoint_type": etype,
        "connection_type": "",
        "configuration": {},
        "connector_instance": None,
    }


def _seed_remap() -> RemapTable:
    """Remap table as the workflow and connector phases would leave it."""
    remap = RemapTable()
    remap.record("workflow", SRC_WF, TGT_WF)
    remap.record("connector", SRC_CONN, TGT_CONN)
    return remap


def test_pairs_endpoints_by_type_and_remaps_connector():
    src = FakeClient()
    src.endpoints[SRC_WF] = [
        _src_endpoint(
            "src-ep-source",
            "SOURCE",
            SRC_CONN,
            {"connector_id": SRC_CONN, "path": "/in"},
        ),
        _src_endpoint(
            "src-ep-dest",
            "DESTINATION",
            SRC_CONN,
            {"connector_id": SRC_CONN, "path": "/out"},
        ),
    ]
    tgt = FakeClient()
    tgt.endpoints[TGT_WF] = [
        _tgt_endpoint("tgt-ep-source", "SOURCE"),
        _tgt_endpoint("tgt-ep-dest", "DESTINATION"),
    ]
    ctx = _ctx(src, tgt, remap=_seed_remap())

    result = WorkflowEndpointPhase(ctx).run(MigrationReport())

    assert result.created == 2
    assert result.failed == 0
    assert len(tgt.patch_calls) == 2

    patches_by_id = dict(tgt.patch_calls)
    src_patch = patches_by_id["tgt-ep-source"]
    assert src_patch["connection_type"] == "FILESYSTEM"
    assert src_patch["connector_instance_id"] == TGT_CONN
    assert src_patch["configuration"]["connector_id"] == TGT_CONN
    assert src_patch["configuration"]["path"] == "/in"

    dst_patch = patches_by_id["tgt-ep-dest"]
    assert dst_patch["configuration"]["path"] == "/out"
    assert dst_patch["connector_instance_id"] == TGT_CONN

    assert ctx.remap.resolve("workflow_endpoint", "src-ep-source") == "tgt-ep-source"
    assert ctx.remap.resolve("workflow_endpoint", "src-ep-dest") == "tgt-ep-dest"


def test_endpoint_without_source_connector_patches_with_null():
    src = FakeClient()
    src.endpoints[SRC_WF] = [
        {
            "id": "src-ep-source",
            "endpoint_type": "SOURCE",
            "connection_type": "API",
            "configuration": {"foo": "bar"},
            "connector_instance": None,
        }
    ]
    tgt = FakeClient()
    tgt.endpoints[TGT_WF] = [_tgt_endpoint("tgt-ep-source", "SOURCE")]
    ctx = _ctx(src, tgt, remap=_seed_remap())

    result = WorkflowEndpointPhase(ctx).run(MigrationReport())

    assert result.created == 1
    assert len(tgt.patch_calls) == 1
    _, payload = tgt.patch_calls[0]
    assert payload["connector_instance_id"] is None
    assert payload["configuration"] == {"foo": "bar"}


def test_unknown_connector_uuid_logs_but_does_not_fail():
    src = FakeClient()
    src.endpoints[SRC_WF] = [
        # The connector id is a real UUID shape, just not one the remap knows.
        _src_endpoint("src-ep-source", "SOURCE", UNMAPPED_CONN, {})
    ]
    tgt = FakeClient()
    tgt.endpoints[TGT_WF] = [_tgt_endpoint("tgt-ep-source", "SOURCE")]
    ctx = _ctx(src, tgt, remap=_seed_remap())

    result = WorkflowEndpointPhase(ctx).run(MigrationReport())

    assert result.created == 1
    assert result.failed == 0
    # No remap → connector_instance_id stays None instead of failing the PATCH.
    _, payload = tgt.patch_calls[0]
    assert payload["connector_instance_id"] is None


def test_missing_target_endpoint_fails_loudly():
    src = FakeClient()
    src.endpoints[SRC_WF] = [
        _src_endpoint("src-ep-source", "SOURCE", SRC_CONN, {})
    ]
    tgt = FakeClient()
    tgt.endpoints[TGT_WF] = []  # No endpoints — anomaly.
    ctx = _ctx(src, tgt, remap=_seed_remap())

    result = WorkflowEndpointPhase(ctx).run(MigrationReport())

    assert result.failed == 1
    # The failure must be attributable from the report, not only the log.
    assert any("missing tgt SOURCE endpoint" in err for err in result.errors)
    assert tgt.patch_calls == []


def test_dry_run_makes_no_patches():
    src = FakeClient()
    src.endpoints[SRC_WF] = [
        _src_endpoint("src-ep-source", "SOURCE", SRC_CONN, {})
    ]
    tgt = FakeClient()
    tgt.endpoints[TGT_WF] = [_tgt_endpoint("tgt-ep-source", "SOURCE")]
    ctx = _ctx(src, tgt, remap=_seed_remap(), dry_run=True)

    result = WorkflowEndpointPhase(ctx).run(MigrationReport())

    assert result.skipped == 1
    assert tgt.patch_calls == []


def test_no_workflows_in_remap_is_noop():
    src = FakeClient()
    tgt = FakeClient()
    ctx = _ctx(src, tgt, remap=RemapTable())

    result = WorkflowEndpointPhase(ctx).run(MigrationReport())

    assert result.created == 0
    assert tgt.patch_calls == []
charset=UTF-8 Content-Transfer-Encoding: 8bit Drop the field-by-field reconcile loop (create_custom_tool + delete_profile + create_profile + set_default_profile + create_prompt) in favour of the backend's purpose-built endpoints: - GET prompt-studio/project-transfer/{id} bundles tool_metadata, tool_settings, default_profile_settings, prompts in one shot. - POST prompt-studio/project-transfer/ creates the tool, default profile (wired with target-org adapter ids the SDK supplies), and prompts server-side in one call. - POST prompt-studio/{id}/sync-prompts/ rip-and-replaces prompts on an existing target tool for the adopt path. Adapter ids for the import are resolved from the source's default ProfileManager via the adapter remap table; missing remap fails the tool cleanly instead of landing a half-wired profile. Removes hardcoded PROFILE_WRITABLE / PROMPT_WRITABLE frozensets and the OPTIONS schema fetch for prompt-studio — the project-transfer endpoint owns the field shape server-side. Co-Authored-By: Claude Opus 4.7 --- src/unstract/migration/client.py | 132 ++++-- src/unstract/migration/phases/custom_tool.py | 415 +++++++---------- tests/migration/test_custom_tool_phase.py | 451 ++++++++----------- 3 files changed, 456 insertions(+), 542 deletions(-) diff --git a/src/unstract/migration/client.py b/src/unstract/migration/client.py index ef622b2..99187e0 100644 --- a/src/unstract/migration/client.py +++ b/src/unstract/migration/client.py @@ -10,6 +10,7 @@ from __future__ import annotations +import json as json_lib import logging from typing import Any @@ -54,6 +55,8 @@ def _request( *, params: dict[str, Any] | None = None, json: Any = None, + files: dict[str, Any] | None = None, + data: dict[str, Any] | None = None, ) -> Any: url = self._url(path) # Redact secrets from logs: only entity path + method, never body. 
@@ -63,6 +66,8 @@ def _request( url, params=params, json=json, + files=files, + data=data, timeout=self.timeout, verify=self.verify, ) @@ -164,67 +169,104 @@ def list_custom_tools(self) -> list[dict[str, Any]]: result = self._request("GET", "prompt-studio/") return result if isinstance(result, list) else result.get("results", []) - def get_custom_tool(self, tool_id: str) -> dict[str, Any]: - """Tool detail; response includes embedded ``prompts`` + ``default_profile``.""" - return self._request("GET", f"prompt-studio/{tool_id}/") - - def create_custom_tool(self, payload: dict[str, Any]) -> dict[str, Any]: - """Create a custom tool. Backend also auto-creates one default ProfileManager.""" - return self._request("POST", "prompt-studio/", json=payload) - - def export_custom_tool(self, tool_id: str, *, force: bool = True) -> Any: - """Republish ``PromptStudioRegistry`` from the tool's current target state. + def list_profiles(self, tool_id: str) -> list[dict[str, Any]]: + """List ProfileManager rows for a tool. - Used after profile+prompt reconciliation so the registry row is - rebuilt without the SDK ever carrying ``tool_metadata`` across orgs. + Migration reads this on the source only — to discover the + default profile's adapter UUIDs so they can be remapped to + target adapter ids for ``import_project``. 
""" - return self._request( - "POST", - f"prompt-studio/export/{tool_id}", - json={ - "is_shared_with_org": False, - "user_id": [], - "force_export": force, - }, - ) - - # ----- profile managers ----- - - def list_profiles(self, tool_id: str) -> list[dict[str, Any]]: - """List ProfileManager rows for a tool via the per-tool list action.""" result = self._request( "GET", f"prompt-studio/prompt-studio-profile/{tool_id}/" ) return result if isinstance(result, list) else result.get("results", []) - def create_profile(self, tool_id: str, payload: dict[str, Any]) -> dict[str, Any]: - """POST to ``prompt-studio/profilemanager/{tool_id}`` (no trailing slash).""" + def export_project(self, tool_id: str) -> dict[str, Any]: + """Export a prompt-studio project as a portable JSON blob. + + Bundles ``tool_metadata``, ``tool_settings``, + ``default_profile_settings``, ``prompts``, ``export_metadata`` in + one shot — feed straight into ``import_project`` or + ``sync_prompts`` on the target. + """ + return self._request("GET", f"prompt-studio/project-transfer/{tool_id}") + + def import_project( + self, + export_data: dict[str, Any], + adapter_ids: dict[str, str | None] | None = None, + ) -> dict[str, Any]: + """Import a prompt-studio project from an export blob. + + Backend creates the tool, builds the default ProfileManager from + the supplied target-org adapter ids, and imports all prompts in + one call. On name collision the backend silently uniquifies the + new tool's name — callers should pre-check via + ``list_custom_tools`` to avoid that. + + ``adapter_ids`` keys are the backend's form fields: + ``llm_adapter_id``, ``vector_db_adapter_id``, + ``embedding_adapter_id``, ``x2text_adapter_id``. All four + required to wire the profile; otherwise backend falls back to + a profile without adapters and flags ``needs_adapter_config``. 
+ """ + tool_name = ( + export_data.get("tool_metadata", {}).get("tool_name") or "export" + ) + content = json_lib.dumps(export_data).encode() + files = {"file": (f"{tool_name}.json", content, "application/json")} + data: dict[str, Any] = {} + if adapter_ids: + for key in ( + "llm_adapter_id", + "vector_db_adapter_id", + "embedding_adapter_id", + "x2text_adapter_id", + ): + val = adapter_ids.get(key) + if val: + data[key] = val return self._request( - "POST", f"prompt-studio/profilemanager/{tool_id}", json=payload + "POST", + "prompt-studio/project-transfer/", + files=files, + data=data, ) - def delete_profile(self, profile_id: str) -> None: - self._request("DELETE", f"profile-manager/{profile_id}/") + def sync_prompts( + self, + tool_id: str, + export_data: dict[str, Any], + *, + create_copy: bool = False, + ) -> dict[str, Any]: + """Rip-and-replace prompts on an existing target tool. - def set_default_profile(self, tool_id: str, profile_id: str) -> Any: - """Mark a single profile as default for this tool (zeros the rest).""" + Adopt path: target tool already exists with its own + adapter-bound profiles. This overwrites its prompt set (and + ``tool_settings``) from source; profiles and uploaded documents + are left untouched. + """ + payload = {"data": export_data, "create_copy": create_copy} return self._request( - "PATCH", - f"prompt-studio/prompt-studio-profile/{tool_id}/", - json={"default_profile": profile_id}, + "POST", f"prompt-studio/{tool_id}/sync-prompts/", json=payload ) - # ----- prompts ----- - - def list_prompts(self, *, tool_id: str) -> list[dict[str, Any]]: - """List prompts filtered by tool_id (FilterHelper-backed).""" - result = self._request("GET", "prompt/", params={"tool_id": tool_id}) - return result if isinstance(result, list) else result.get("results", []) + def export_custom_tool(self, tool_id: str, *, force: bool = True) -> Any: + """Republish ``PromptStudioRegistry`` from the tool's current state. 
- def create_prompt(self, tool_id: str, payload: dict[str, Any]) -> dict[str, Any]: - """POST to ``prompt-studio/prompt-studio-prompt/{tool_id}/`` (create_prompt action).""" + Called after import/sync so the registry row reflects the + freshly landed prompts. Required for ToolInstancePhase to find + a target registry id to remap. + """ return self._request( - "POST", f"prompt-studio/prompt-studio-prompt/{tool_id}/", json=payload + "POST", + f"prompt-studio/export/{tool_id}", + json={ + "is_shared_with_org": False, + "user_id": [], + "force_export": force, + }, ) # ----- workflows ----- diff --git a/src/unstract/migration/phases/custom_tool.py b/src/unstract/migration/phases/custom_tool.py index d194ead..58873d6 100644 --- a/src/unstract/migration/phases/custom_tool.py +++ b/src/unstract/migration/phases/custom_tool.py @@ -1,31 +1,28 @@ -"""Migrate prompt-studio projects (CustomTool) and their children. - -Composite phase: a single project carries ``ProfileManager`` rows (LLM -triad config) and ``ToolStudioPrompt`` rows (the actual prompts). All -three must land together for the project to be functional on target, so -they live in one phase rather than three sibling phases. - -Within a project, the create order is: - - 1. CustomTool — POST creates the project and auto-creates one default - ProfileManager on target. - 2. ProfileManagers — on a freshly-created tool we delete the auto-default - first so the source's profiles land cleanly. On an adopted tool we - reconcile by ``profile_name`` (per-tool unique). - 3. ToolStudioPrompts — reconcile by ``prompt_key`` (per-tool unique). - 4. Republish PromptStudioRegistry via the ``export-tool`` action so the - registry row is rebuilt server-side from the now-correct child state. - Avoids the SDK carrying ``tool_metadata`` JSON across orgs. 
- -Walker remapping: adapter UUIDs embedded in the tool's adapter FKs -(``monitor_llm``, ``challenge_llm``, ``summarize_llm_adapter``), in the -profile's adapter FKs (``llm``, ``embedding_model``, ``vector_store``, -``x2text``), and in the prompt's ``profile_manager`` + ``tool_id`` FKs -are remapped before POST using the running ``RemapTable``. - -The ProfileManager GET response expands adapter FKs into nested adapter -dicts (per the backend serializer's ``to_representation``); we flatten -them back to UUIDs before walker pass. +"""Migrate prompt-studio projects via the project-transfer endpoints. + +For each source tool the phase: + +1. ``GET prompt-studio/project-transfer/{src_tool_id}`` — pulls a + portable JSON blob (tool_metadata, tool_settings, + default_profile_settings, prompts, export_metadata). +2. Decides fresh vs adopt by looking up the target tool by name. +3. **Fresh path**: reads source's default ProfileManager to learn the + adapter UUIDs the profile is bound to, remaps each via the running + ``adapter`` remap table, and POSTs the import as a multipart upload + with target-org adapter ids on the form. Backend creates the tool, + the default profile, and all prompts server-side in one call. +4. **Adopt path**: POSTs ``sync-prompts`` on the existing target tool. + Backend rip-and-replaces prompts + ``tool_settings`` and leaves the + target's locally-configured profiles + adapters untouched (which is + what the operator wants — they may have rewired adapters on target). +5. Republishes ``PromptStudioRegistry`` via the export action and + records the ``custom_tool`` + ``prompt_studio_registry`` remaps so + downstream ToolInstancePhase can rewrite ``ToolInstance.tool_id``. + +Adapter id discovery for the fresh path needs all four of LLM, +vector_db, embedding, x2text. If any source adapter can't be resolved +via the adapter remap, the tool is failed cleanly — we never want to +land a half-wired profile. 
""" from __future__ import annotations @@ -34,77 +31,28 @@ from typing import Any from unstract.migration.exceptions import NameConflictError -from unstract.migration.phases.base import SERVER_MANAGED, Phase, build_post_payload +from unstract.migration.phases.base import Phase from unstract.migration.report import MigrationReport, PhaseResult -from unstract.migration.walker import remap_uuids logger = logging.getLogger(__name__) -TOOL_PATH = "prompt-studio/" - -# Per-action endpoints on PromptStudioCoreView don't surface their own -# DRF metadata (OPTIONS returns the parent CustomToolSerializer schema). -# Hardcode the model-derived writable subset for the children and let the -# integration test catch backend drift. -PROFILE_WRITABLE: frozenset[str] = frozenset( - { - "profile_name", - "vector_store", - "embedding_model", - "llm", - "x2text", - "chunk_size", - "chunk_overlap", - "reindex", - "retrieval_strategy", - "similarity_top_k", - "section", - "prompt_studio_tool", - "is_default", - "is_summarize_llm", - } +_PROFILE_ADAPTER_FIELDS: tuple[tuple[str, str], ...] 
= ( + ("llm", "llm_adapter_id"), + ("vector_store", "vector_db_adapter_id"), + ("embedding_model", "embedding_adapter_id"), + ("x2text", "x2text_adapter_id"), ) -PROMPT_WRITABLE: frozenset[str] = frozenset( - { - "prompt_key", - "enforce_type", - "prompt", - "tool_id", - "sequence_number", - "prompt_type", - "profile_manager", - "output", - "assert_prompt", - "assertion_failure_prompt", - "required", - "is_assert", - "active", - "output_metadata", - "postprocessing_webhook_url", - "evaluate", - "eval_quality_faithfulness", - "eval_quality_correctness", - "eval_quality_relevance", - "eval_security_pii", - "eval_guidance_toxicity", - "eval_guidance_completeness", - } -) - -_PROFILE_ADAPTER_KEYS = ("llm", "embedding_model", "vector_store", "x2text") - -def _flatten_profile_adapters(profile: dict[str, Any]) -> dict[str, Any]: - """ProfileManagerSerializer.to_representation expands FK adapters into - nested dicts; for write paths we need flat UUIDs back. +def _extract_adapter_id(value: Any) -> str | None: + """Profile FKs come back as nested dicts via serializer expansion; + pull the UUID back out for either flat-string or nested-dict shapes. 
""" - out = dict(profile) - for key in _PROFILE_ADAPTER_KEYS: - val = out.get(key) - if isinstance(val, dict) and "id" in val: - out[key] = val["id"] - return out + if isinstance(value, dict): + return value.get("id") + if isinstance(value, str): + return value + return None class CustomToolPhase(Phase): @@ -112,14 +60,6 @@ class CustomToolPhase(Phase): def run(self, report: MigrationReport) -> PhaseResult: result = report.get_phase(self.name) - try: - self._tool_writable = self.ctx.target.get_post_schema(TOOL_PATH) - except Exception as e: - logger.exception("Failed to fetch target POST schema for prompt-studio: %s", e) - result.failed += 1 - result.errors.append(f"OPTIONS prompt-studio: {e}") - return result - try: src_tools = self.ctx.source.list_custom_tools() except Exception as e: @@ -138,52 +78,36 @@ def _migrate_one(self, summary: dict[str, Any], result: PhaseResult) -> None: src_tool_id = summary["tool_id"] try: - src_tool = self.ctx.source.get_custom_tool(src_tool_id) + export_data = self.ctx.source.export_project(src_tool_id) except Exception as e: - logger.exception("Failed to GET source tool %s: %s", tool_name, e) + logger.exception("Failed to export source tool '%s': %s", tool_name, e) result.failed += 1 - result.errors.append(f"GET source tool {tool_name}: {e}") - return - - tgt_tool, fresh = self._get_or_create_tool(src_tool, result) - if tgt_tool is None: - return - - tgt_tool_id = tgt_tool["tool_id"] - self.ctx.remap.record("custom_tool", src_tool_id, tgt_tool_id) - - if self.ctx.options.dry_run: - logger.info( - "[dry-run] would reconcile profiles+prompts for tool '%s' src=%s", - tool_name, src_tool_id, - ) + result.errors.append(f"export src tool {tool_name}: {e}") return try: - src_profiles = self.ctx.source.list_profiles(src_tool_id) + target_tools = self.ctx.target.list_custom_tools() except Exception as e: - logger.exception("Failed to list source profiles for %s: %s", tool_name, e) + logger.exception("Failed to list target tools: %s", e) 
result.failed += 1 - result.errors.append(f"list src profiles {tool_name}: {e}") + result.errors.append(f"list target tools: {e}") return + match = next( + (t for t in target_tools if t["tool_name"] == tool_name), None + ) - try: - self._reconcile_profiles(src_profiles, tgt_tool_id, fresh) - except Exception as e: - logger.exception("Profile reconcile failed for tool %s: %s", tool_name, e) - result.failed += 1 - result.errors.append(f"profiles {tool_name}: {e}") - return + if match is not None: + tgt_tool_id = self._adopt(match, export_data, result, tool_name, src_tool_id) + else: + tgt_tool_id = self._create_fresh( + export_data, src_tool_id, tool_name, result + ) - try: - src_prompts = src_tool.get("prompts") or [] - self._reconcile_prompts(src_prompts, tgt_tool_id) - except Exception as e: - logger.exception("Prompt reconcile failed for tool %s: %s", tool_name, e) - result.failed += 1 - result.errors.append(f"prompts {tool_name}: {e}") + if tgt_tool_id is None: return + self.ctx.remap.record("custom_tool", src_tool_id, tgt_tool_id) + try: self.ctx.target.export_custom_tool(tgt_tool_id) logger.info("republished registry for tool '%s' tgt=%s", tool_name, tgt_tool_id) @@ -193,147 +117,146 @@ def _migrate_one(self, summary: dict[str, Any], result: PhaseResult) -> None: result.errors.append(f"export {tool_name}: {e}") return - # Record the registry-id remap so ToolInstancePhase can rewrite - # ToolInstance.tool_id (which holds a registry UUID as CharField). - # Source-side registry exists only if the operator already published - # the tool; un-published tools have no ToolInstance to migrate. + # Record registry remap so ToolInstancePhase can rewrite + # ToolInstance.tool_id (which stores a registry UUID as CharField). + # Source registry exists only if the operator already published + # the tool there; unpublished source tools simply produce no + # ToolInstance rows for downstream to remap. 
try: src_regs = self.ctx.source.list_registries(custom_tool=src_tool_id) tgt_regs = self.ctx.target.list_registries(custom_tool=tgt_tool_id) except Exception as e: logger.warning( - "registry remap lookup failed for tool '%s' (downstream " - "ToolInstance migration may skip): %s", + "registry remap lookup failed for tool '%s' " + "(downstream ToolInstance migration may skip): %s", tool_name, e, ) return if src_regs and tgt_regs: - src_reg_id = src_regs[0]["prompt_registry_id"] - tgt_reg_id = tgt_regs[0]["prompt_registry_id"] self.ctx.remap.record( - "prompt_studio_registry", src_reg_id, tgt_reg_id + "prompt_studio_registry", + src_regs[0]["prompt_registry_id"], + tgt_regs[0]["prompt_registry_id"], + ) + + def _adopt( + self, + match: dict[str, Any], + export_data: dict[str, Any], + result: PhaseResult, + tool_name: str, + src_tool_id: str, + ) -> str | None: + if self.ctx.options.on_name_conflict == "abort": + raise NameConflictError( + f"tool '{tool_name}' already exists in target as {match['tool_id']}" ) - def _get_or_create_tool( - self, src_tool: dict[str, Any], result: PhaseResult - ) -> tuple[dict[str, Any] | None, bool]: - tool_name = src_tool["tool_name"] - src_tool_id = src_tool["tool_id"] + tgt_tool_id = match["tool_id"] + if self.ctx.options.dry_run: + result.skipped += 1 + logger.info( + "[dry-run] would sync prompts into adopted tool '%s' src=%s -> tgt=%s", + tool_name, src_tool_id, tgt_tool_id, + ) + return tgt_tool_id try: - target_tools = self.ctx.target.list_custom_tools() + self.ctx.target.sync_prompts(tgt_tool_id, export_data) except Exception as e: - logger.exception("Failed to list target tools: %s", e) + logger.exception("sync_prompts failed for tool %s: %s", tool_name, e) result.failed += 1 - result.errors.append(f"list target tools: {e}") - return None, False + result.errors.append(f"sync {tool_name}: {e}") + return None - match = next((t for t in target_tools if t["tool_name"] == tool_name), None) - if match is not None: - if 
self.ctx.options.on_name_conflict == "abort": - raise NameConflictError( - f"tool '{tool_name}' already exists in target as {match['tool_id']}" - ) - result.adopted += 1 - logger.info( - "adopted tool '%s' src=%s -> tgt=%s", - tool_name, src_tool_id, match["tool_id"], - ) - return match, False + result.adopted += 1 + logger.info( + "adopted tool '%s' src=%s -> tgt=%s (prompts re-synced)", + tool_name, src_tool_id, tgt_tool_id, + ) + return tgt_tool_id + def _create_fresh( + self, + export_data: dict[str, Any], + src_tool_id: str, + tool_name: str, + result: PhaseResult, + ) -> str | None: if self.ctx.options.dry_run: result.skipped += 1 - logger.info("[dry-run] would create tool '%s' src=%s", tool_name, src_tool_id) - return None, True + logger.info( + "[dry-run] would import tool '%s' src=%s", tool_name, src_tool_id + ) + return None + + adapter_ids = self._resolve_target_adapter_ids(src_tool_id, tool_name) + if adapter_ids is None: + result.failed += 1 + result.errors.append( + f"import {tool_name}: missing target adapter remap for default profile" + ) + return None - remapped = remap_uuids(src_tool, self.ctx.remap) - payload = build_post_payload(remapped, self._tool_writable) try: - tgt = self.ctx.target.create_custom_tool(payload) + tgt = self.ctx.target.import_project(export_data, adapter_ids=adapter_ids) except Exception as e: - logger.exception("Failed to create tool %s: %s", tool_name, e) + logger.exception("import_project failed for tool %s: %s", tool_name, e) result.failed += 1 - result.errors.append(f"create tool {tool_name}: {e}") - return None, True + result.errors.append(f"import {tool_name}: {e}") + return None + + tgt_tool_id = tgt["tool_id"] result.created += 1 logger.info( - "created tool '%s' src=%s -> tgt=%s", - tool_name, src_tool_id, tgt["tool_id"], + "created tool '%s' src=%s -> tgt=%s (needs_adapter_config=%s)", + tool_name, src_tool_id, tgt_tool_id, tgt.get("needs_adapter_config"), ) - return tgt, True + return tgt_tool_id - def 
_reconcile_profiles( - self, - src_profiles: list[dict[str, Any]], - tgt_tool_id: str, - fresh: bool, - ) -> None: - if fresh: - for p in self.ctx.target.list_profiles(tgt_tool_id): - self.ctx.target.delete_profile(p["profile_id"]) - logger.debug("deleted auto-default profile %s", p["profile_id"]) - - src_default_id: str | None = None - for src_profile in src_profiles: - src_pid = src_profile["profile_id"] - if src_profile.get("is_default"): - src_default_id = src_pid - - target_profiles_by_name = { - p["profile_name"]: p - for p in self.ctx.target.list_profiles(tgt_tool_id) - } - existing = target_profiles_by_name.get(src_profile["profile_name"]) - - if existing is not None: - tgt_pid = existing["profile_id"] - logger.info( - "adopted profile '%s' src=%s -> tgt=%s", - src_profile["profile_name"], src_pid, tgt_pid, - ) - else: - flat = _flatten_profile_adapters(src_profile) - remapped = remap_uuids(flat, self.ctx.remap) - remapped["prompt_studio_tool"] = tgt_tool_id - payload = build_post_payload(remapped, PROFILE_WRITABLE) - tgt = self.ctx.target.create_profile(tgt_tool_id, payload) - tgt_pid = tgt["profile_id"] - logger.info( - "created profile '%s' src=%s -> tgt=%s", - src_profile["profile_name"], src_pid, tgt_pid, - ) - self.ctx.remap.record("profile_manager", src_pid, tgt_pid) - - if src_default_id: - tgt_default = self.ctx.remap.resolve("profile_manager", src_default_id) - if tgt_default: - self.ctx.target.set_default_profile(tgt_tool_id, tgt_default) - - def _reconcile_prompts( - self, src_prompts: list[dict[str, Any]], tgt_tool_id: str - ) -> None: - existing_prompts = self.ctx.target.list_prompts(tool_id=tgt_tool_id) - by_key = {p["prompt_key"]: p for p in existing_prompts} - - for src_prompt in src_prompts: - src_prompt_id = src_prompt["prompt_id"] - key = src_prompt["prompt_key"] - existing = by_key.get(key) - if existing is not None: - tgt_pid = existing["prompt_id"] - logger.info( - "adopted prompt '%s' src=%s -> tgt=%s", - key, src_prompt_id, tgt_pid, 
+ def _resolve_target_adapter_ids( + self, src_tool_id: str, tool_name: str + ) -> dict[str, str] | None: + """Read source default profile → remap each adapter UUID to target. + + Returns ``None`` if any of the four required adapters can't be + resolved via the ``adapter`` remap — caller fails the tool. + """ + try: + src_profiles = self.ctx.source.list_profiles(src_tool_id) + except Exception as e: + logger.exception( + "Failed to list source profiles for tool %s: %s", tool_name, e + ) + return None + + default = next( + (p for p in src_profiles if p.get("is_default")), + src_profiles[0] if src_profiles else None, + ) + if default is None: + logger.warning( + "source tool '%s' has no profiles to derive adapter ids from", + tool_name, + ) + return None + + resolved: dict[str, str] = {} + for src_field, form_field in _PROFILE_ADAPTER_FIELDS: + src_adapter_id = _extract_adapter_id(default.get(src_field)) + if not src_adapter_id: + logger.warning( + "source default profile for tool '%s' missing adapter '%s'", + tool_name, src_field, ) - else: - remapped = remap_uuids(src_prompt, self.ctx.remap) - remapped["tool_id"] = tgt_tool_id - payload = build_post_payload(remapped, PROMPT_WRITABLE - SERVER_MANAGED) - tgt = self.ctx.target.create_prompt(tgt_tool_id, payload) - tgt_pid = tgt["prompt_id"] - logger.info( - "created prompt '%s' src=%s -> tgt=%s", - key, src_prompt_id, tgt_pid, + return None + tgt_adapter_id = self.ctx.remap.resolve("adapter", src_adapter_id) + if not tgt_adapter_id: + logger.warning( + "no adapter remap for %s (field %s) on tool '%s'", + src_adapter_id, src_field, tool_name, ) - self.ctx.remap.record("prompt", src_prompt_id, tgt_pid) + return None + resolved[form_field] = tgt_adapter_id + return resolved diff --git a/tests/migration/test_custom_tool_phase.py b/tests/migration/test_custom_tool_phase.py index 2ee302b..b9ddfef 100644 --- a/tests/migration/test_custom_tool_phase.py +++ b/tests/migration/test_custom_tool_phase.py @@ -1,99 +1,73 @@ 
-"""Tests for ``CustomToolPhase`` — the composite tool + profile + prompt phase. +"""Tests for ``CustomToolPhase`` — project-transfer + sync-prompts based. Coverage: -- happy path: fresh tool → tool created, auto-default deleted, source - profiles + prompts created, registry republished. -- idempotency: re-run is a no-op when tool + profile names + prompt keys - already match on target. -- adopt path: existing tool on target, partial overlap of profiles — - matching ones adopted, missing ones created. -- dry-run: nothing posted. -- adapter UUID remap into profile FKs. +- fresh path: ``export_project`` on source → ``import_project`` on + target with adapter ids resolved from source's default profile and + remapped via the adapter table. +- adopt path: existing target tool with matching name → + ``sync_prompts`` overwrites prompts; no profile/adapter writes. +- registry remap recorded after ``export_custom_tool``. +- dry-run: no writes on either side. +- abort on name conflict when option is set. +- missing adapter remap fails the tool cleanly. 
""" from __future__ import annotations +import pytest + from unstract.migration.context import ( MigrationContext, MigrationOptions, RemapTable, ) +from unstract.migration.exceptions import NameConflictError from unstract.migration.phases.custom_tool import CustomToolPhase from unstract.migration.report import MigrationReport -TOOL_POST_SCHEMA = frozenset( - { - "tool_name", - "description", - "author", - "icon", - "preamble", - "postamble", - "prompt_grammer", - "monitor_llm", - "challenge_llm", - "summarize_llm_adapter", - "custom_data", - "single_pass_extraction_mode", - "shared_users", - "shared_to_org", - } -) +SRC_LLM = "11111111-1111-1111-1111-111111111111" +SRC_EMB = "22222222-2222-2222-2222-222222222222" +SRC_VEC = "33333333-3333-3333-3333-333333333333" +SRC_X2T = "44444444-4444-4444-4444-444444444444" +TGT_LLM = "a1111111-1111-1111-1111-111111111111" +TGT_EMB = "a2222222-2222-2222-2222-222222222222" +TGT_VEC = "a3333333-3333-3333-3333-333333333333" +TGT_X2T = "a4444444-4444-4444-4444-444444444444" +SRC_REG = "55555555-5555-5555-5555-555555555555" class FakeClient: - """In-memory stand-in for ``PlatformClient`` covering the prompt-studio surface.""" + """In-memory stand-in for ``PlatformClient`` covering project-transfer.""" def __init__(self) -> None: self.tools: dict[str, dict] = {} self.profiles_by_tool: dict[str, list[dict]] = {} - self.prompts_by_tool: dict[str, list[dict]] = {} + self.export_blobs: dict[str, dict] = {} self.registries_by_tool: dict[str, dict] = {} - self.export_calls: list[str] = [] + # Call recorders. 
+ self.import_calls: list[tuple[dict, dict | None]] = [] + self.sync_calls: list[tuple[str, dict, bool]] = [] + self.export_tool_calls: list[str] = [] self._next = 1 - # --- ID helper --- def _mint(self, prefix: str) -> str: s = f"tgt-{prefix}-{self._next:04d}" self._next += 1 return s - # --- schema --- - def get_post_schema(self, entity_path: str) -> frozenset[str]: - if entity_path == "prompt-studio/": - return TOOL_POST_SCHEMA - raise AssertionError(f"unexpected OPTIONS path: {entity_path}") - - # --- tools --- + # --- reads --- def list_custom_tools(self) -> list[dict]: - return list(self.tools.values()) + return [ + {"tool_id": tid, "tool_name": t["tool_name"]} + for tid, t in self.tools.items() + ] - def get_custom_tool(self, tool_id: str) -> dict: - return self.tools[tool_id] - - def create_custom_tool(self, payload: dict) -> dict: - tool_id = self._mint("tool") - tool = {**payload, "tool_id": tool_id, "prompts": []} - self.tools[tool_id] = tool - # Backend auto-creates a default profile on create. - auto = { - "profile_id": self._mint("autoprofile"), - "profile_name": "Default", - "is_default": True, - "prompt_studio_tool": tool_id, - } - self.profiles_by_tool[tool_id] = [auto] - self.prompts_by_tool[tool_id] = [] - return tool + def list_profiles(self, tool_id: str) -> list[dict]: + return list(self.profiles_by_tool.get(tool_id, [])) - def export_custom_tool(self, tool_id: str, *, force: bool = True) -> None: - self.export_calls.append(tool_id) - # Mimic the backend: export creates/updates a registry row for the tool. 
- self.registries_by_tool.setdefault( - tool_id, - {"prompt_registry_id": self._mint("registry"), "custom_tool": tool_id}, - ) + def export_project(self, tool_id: str) -> dict: + return self.export_blobs[tool_id] def list_registries(self, *, custom_tool: str | None = None) -> list[dict]: if custom_tool is None: @@ -101,33 +75,36 @@ def list_registries(self, *, custom_tool: str | None = None) -> list[dict]: reg = self.registries_by_tool.get(custom_tool) return [reg] if reg else [] - # --- profiles --- - def list_profiles(self, tool_id: str) -> list[dict]: - return list(self.profiles_by_tool.get(tool_id, [])) - - def create_profile(self, tool_id: str, payload: dict) -> dict: - new = {**payload, "profile_id": self._mint("profile")} - self.profiles_by_tool.setdefault(tool_id, []).append(new) - return new - - def delete_profile(self, profile_id: str) -> None: - for tid, profiles in self.profiles_by_tool.items(): - self.profiles_by_tool[tid] = [ - p for p in profiles if p["profile_id"] != profile_id - ] - - def set_default_profile(self, tool_id: str, profile_id: str) -> None: - for p in self.profiles_by_tool.get(tool_id, []): - p["is_default"] = p["profile_id"] == profile_id + # --- writes --- + def import_project( + self, export_data: dict, adapter_ids: dict | None = None + ) -> dict: + self.import_calls.append((export_data, adapter_ids)) + tool_id = self._mint("tool") + tool_name = export_data["tool_metadata"]["tool_name"] + self.tools[tool_id] = {"tool_name": tool_name} + return { + "tool_id": tool_id, + "message": f"Project imported successfully as '{tool_name}'", + "needs_adapter_config": adapter_ids is None, + } - # --- prompts --- - def list_prompts(self, *, tool_id: str) -> list[dict]: - return list(self.prompts_by_tool.get(tool_id, [])) + def sync_prompts( + self, tool_id: str, export_data: dict, *, create_copy: bool = False + ) -> dict: + self.sync_calls.append((tool_id, export_data, create_copy)) + return { + "prompts_created": len(export_data.get("prompts", 
[])), + "prompts_deleted": 0, + "tool_settings_updated": True, + } - def create_prompt(self, tool_id: str, payload: dict) -> dict: - new = {**payload, "prompt_id": self._mint("prompt")} - self.prompts_by_tool.setdefault(tool_id, []).append(new) - return new + def export_custom_tool(self, tool_id: str, *, force: bool = True) -> None: + self.export_tool_calls.append(tool_id) + self.registries_by_tool.setdefault( + tool_id, + {"prompt_registry_id": self._mint("registry"), "custom_tool": tool_id}, + ) def _ctx(source, target, *, remap=None, **opt_overrides) -> MigrationContext: @@ -139,224 +116,196 @@ def _ctx(source, target, *, remap=None, **opt_overrides) -> MigrationContext: ) -def _src_tool(tool_id: str, name: str, prompts: list[dict] | None = None) -> dict: - return { - "tool_id": tool_id, - "tool_name": name, - "description": f"{name} desc", - "author": "src", - "icon": "", - "preamble": "", - "postamble": "", - "prompt_grammer": {}, - "monitor_llm": None, - "challenge_llm": None, - "summarize_llm_adapter": None, - "custom_data": {}, - "single_pass_extraction_mode": False, - "shared_users": [], - "shared_to_org": False, - "prompts": prompts or [], - } +def _seed_adapter_remap(remap: RemapTable) -> None: + remap.record("adapter", SRC_LLM, TGT_LLM) + remap.record("adapter", SRC_EMB, TGT_EMB) + remap.record("adapter", SRC_VEC, TGT_VEC) + remap.record("adapter", SRC_X2T, TGT_X2T) + +def _src_default_profile(*, nested: bool = True) -> dict: + """Mimic ProfileManager serializer output. -def _src_profile(pid: str, name: str, *, is_default: bool = False) -> dict: + ``nested=True`` matches ``to_representation`` expanding FK adapters + into nested dicts; ``nested=False`` covers the raw-UUID fallback. 
+ """ + if nested: + return { + "profile_id": "src-profile-1", + "profile_name": "Default", + "is_default": True, + "llm": {"id": SRC_LLM, "adapter_name": "L"}, + "embedding_model": {"id": SRC_EMB, "adapter_name": "E"}, + "vector_store": {"id": SRC_VEC, "adapter_name": "V"}, + "x2text": {"id": SRC_X2T, "adapter_name": "X"}, + } return { - "profile_id": pid, - "profile_name": name, - "is_default": is_default, - "is_summarize_llm": False, - # Mimic to_representation expansion: nested adapter dicts. - "llm": {"id": "11111111-1111-1111-1111-111111111111", "adapter_name": "L"}, - "embedding_model": {"id": "22222222-2222-2222-2222-222222222222", "adapter_name": "E"}, - "vector_store": {"id": "33333333-3333-3333-3333-333333333333", "adapter_name": "V"}, - "x2text": {"id": "44444444-4444-4444-4444-444444444444", "adapter_name": "X"}, - "chunk_size": 1024, - "chunk_overlap": 128, - "reindex": False, - "retrieval_strategy": "simple", - "similarity_top_k": 3, - "section": "default", - "prompt_studio_tool": None, + "profile_id": "src-profile-1", + "profile_name": "Default", + "is_default": True, + "llm": SRC_LLM, + "embedding_model": SRC_EMB, + "vector_store": SRC_VEC, + "x2text": SRC_X2T, } -def _src_prompt(prompt_id: str, key: str, profile_id: str) -> dict: +def _src_export_blob(tool_name: str) -> dict: return { - "prompt_id": prompt_id, - "prompt_key": key, - "prompt": f"What is {key}?", - "enforce_type": "string", - "prompt_type": "prompt", - "sequence_number": 1, - "tool_id": "src-tool-x", - "profile_manager": profile_id, - "output": "", - "active": True, - "required": False, + "tool_metadata": {"tool_name": tool_name, "description": "x", "author": "a", "icon": None}, + "tool_settings": {"preamble": "p", "postamble": "q"}, + "default_profile_settings": { + "chunk_size": 1024, + "chunk_overlap": 128, + "retrieval_strategy": "simple", + "similarity_top_k": 3, + "section": "default", + "profile_name": "Default", + }, + "prompts": [ + {"prompt_key": "field_a", "prompt": "What 
is field_a?", "sequence_number": 1} + ], + "export_metadata": {"exported_at": "2026-05-24T00:00:00Z"}, } -def _preload_source(client: FakeClient, tool_id: str) -> dict: - """Helper to set up a source FakeClient with one tool, one profile, one prompt.""" - profile = _src_profile("src-profile-1", "Default", is_default=True) - prompt = _src_prompt("src-prompt-1", "field_a", "src-profile-1") - tool = _src_tool(tool_id, "Invoice Extractor", prompts=[prompt]) - client.tools[tool_id] = tool - client.profiles_by_tool[tool_id] = [profile] - client.prompts_by_tool[tool_id] = [prompt] +def _preload_source_tool( + client: FakeClient, tool_id: str, tool_name: str, *, nested_profile: bool = True +) -> None: + client.tools[tool_id] = {"tool_name": tool_name} + client.profiles_by_tool[tool_id] = [_src_default_profile(nested=nested_profile)] + client.export_blobs[tool_id] = _src_export_blob(tool_name) client.registries_by_tool[tool_id] = { - "prompt_registry_id": "55555555-5555-5555-5555-555555555555", + "prompt_registry_id": SRC_REG, "custom_tool": tool_id, } - return tool - - -def _preload_remap_with_adapters(remap: RemapTable) -> None: - remap.record("adapter", "11111111-1111-1111-1111-111111111111", "a1111111-1111-1111-1111-111111111111") - remap.record("adapter", "22222222-2222-2222-2222-222222222222", "a2222222-2222-2222-2222-222222222222") - remap.record("adapter", "33333333-3333-3333-3333-333333333333", "a3333333-3333-3333-3333-333333333333") - remap.record("adapter", "44444444-4444-4444-4444-444444444444", "a4444444-4444-4444-4444-444444444444") -def test_fresh_tool_creates_tool_profiles_prompts_and_republishes(): +def test_fresh_imports_with_remapped_adapter_ids_and_records_registry(): src = FakeClient() tgt = FakeClient() - _preload_source(src, "src-tool-x") + _preload_source_tool(src, "src-tool-x", "Invoice Extractor") remap = RemapTable() - _preload_remap_with_adapters(remap) + _seed_adapter_remap(remap) ctx = _ctx(src, tgt, remap=remap) - report = MigrationReport() - 
result = CustomToolPhase(ctx).run(report) + result = CustomToolPhase(ctx).run(MigrationReport()) assert result.created == 1 assert result.failed == 0 - assert len(tgt.tools) == 1 - tgt_tool_id = next(iter(tgt.tools)) - - # Auto-default profile deleted; exactly one profile (the source's) remains. - profiles = tgt.profiles_by_tool[tgt_tool_id] - assert len(profiles) == 1 - profile = profiles[0] - assert profile["profile_name"] == "Default" - assert profile["is_default"] is True - # Adapter FKs remapped via walker. - assert profile["llm"] == "a1111111-1111-1111-1111-111111111111" - assert profile["embedding_model"] == "a2222222-2222-2222-2222-222222222222" - assert profile["vector_store"] == "a3333333-3333-3333-3333-333333333333" - assert profile["x2text"] == "a4444444-4444-4444-4444-444444444444" - - # One prompt landed, pointing at the new tool. - prompts = tgt.prompts_by_tool[tgt_tool_id] - assert len(prompts) == 1 - assert prompts[0]["prompt_key"] == "field_a" - assert prompts[0]["tool_id"] == tgt_tool_id - - # Registry republished exactly once. - assert tgt.export_calls == [tgt_tool_id] + # Exactly one import_project call with the right export blob + remapped adapter ids. + assert len(tgt.import_calls) == 1 + blob, adapter_ids = tgt.import_calls[0] + assert blob["tool_metadata"]["tool_name"] == "Invoice Extractor" + assert adapter_ids == { + "llm_adapter_id": TGT_LLM, + "vector_db_adapter_id": TGT_VEC, + "embedding_adapter_id": TGT_EMB, + "x2text_adapter_id": TGT_X2T, + } + # No sync_prompts on fresh path. + assert tgt.sync_calls == [] + # Registry republish fired exactly once. + assert len(tgt.export_tool_calls) == 1 + tgt_tool_id = tgt.export_tool_calls[0] # Remap records populated for downstream phases. 
assert ctx.remap.resolve("custom_tool", "src-tool-x") == tgt_tool_id - assert ctx.remap.resolve("profile_manager", "src-profile-1") == profile["profile_id"] - assert ctx.remap.resolve("prompt", "src-prompt-1") == prompts[0]["prompt_id"] - # Registry remap recorded for ToolInstancePhase consumption. - assert ctx.remap.resolve( - "prompt_studio_registry", "55555555-5555-5555-5555-555555555555" - ) == tgt.registries_by_tool[tgt_tool_id]["prompt_registry_id"] + tgt_reg_id = tgt.registries_by_tool[tgt_tool_id]["prompt_registry_id"] + assert ctx.remap.resolve("prompt_studio_registry", SRC_REG) == tgt_reg_id -def test_idempotent_rerun_does_not_create_duplicates(): +def test_flat_uuid_profile_also_resolves_adapter_ids(): src = FakeClient() tgt = FakeClient() - _preload_source(src, "src-tool-x") + _preload_source_tool(src, "src-tool-x", "T", nested_profile=False) remap = RemapTable() - _preload_remap_with_adapters(remap) + _seed_adapter_remap(remap) ctx = _ctx(src, tgt, remap=remap) CustomToolPhase(ctx).run(MigrationReport()) - tgt_tool_id = next(iter(tgt.tools)) - profile_count = len(tgt.profiles_by_tool[tgt_tool_id]) - prompt_count = len(tgt.prompts_by_tool[tgt_tool_id]) - export_count = len(tgt.export_calls) - report2 = MigrationReport() - result2 = CustomToolPhase(ctx).run(report2) + _, adapter_ids = tgt.import_calls[0] + assert adapter_ids["llm_adapter_id"] == TGT_LLM - assert result2.adopted == 1 - assert result2.created == 0 - assert len(tgt.profiles_by_tool[tgt_tool_id]) == profile_count - assert len(tgt.prompts_by_tool[tgt_tool_id]) == prompt_count - # Republish still fires (rebuild registry idempotently). - assert len(tgt.export_calls) == export_count + 1 - -def test_adopt_path_fills_missing_profile_only(): +def test_adopt_path_calls_sync_prompts_and_skips_import(): src = FakeClient() tgt = FakeClient() - # Source has TWO profiles. 
- extra = _src_profile("src-profile-2", "HighRecall", is_default=False) - default = _src_profile("src-profile-1", "Default", is_default=True) - prompt = _src_prompt("src-prompt-1", "field_a", "src-profile-1") - tool = _src_tool("src-tool-x", "Invoice Extractor", prompts=[prompt]) - src.tools["src-tool-x"] = tool - src.profiles_by_tool["src-tool-x"] = [default, extra] - src.prompts_by_tool["src-tool-x"] = [prompt] - - # Target already has the tool + the "Default" profile + the prompt. - tgt_tool_id = "tgt-pre-tool" - tgt.tools[tgt_tool_id] = { - "tool_id": tgt_tool_id, - "tool_name": "Invoice Extractor", - "prompts": [], - } - tgt.profiles_by_tool[tgt_tool_id] = [ - { - "profile_id": "tgt-pre-profile", - "profile_name": "Default", - "is_default": True, - "prompt_studio_tool": tgt_tool_id, - } - ] - tgt.prompts_by_tool[tgt_tool_id] = [ - { - "prompt_id": "tgt-pre-prompt", - "prompt_key": "field_a", - "tool_id": tgt_tool_id, - } - ] + _preload_source_tool(src, "src-tool-x", "Invoice Extractor") + # Target already has the tool with the same name. + tgt.tools["tgt-existing"] = {"tool_name": "Invoice Extractor"} remap = RemapTable() - _preload_remap_with_adapters(remap) + _seed_adapter_remap(remap) ctx = _ctx(src, tgt, remap=remap) result = CustomToolPhase(ctx).run(MigrationReport()) assert result.adopted == 1 - # Adopted tool path: only the missing "HighRecall" profile got created. - profiles = tgt.profiles_by_tool[tgt_tool_id] - assert len(profiles) == 2 - names = {p["profile_name"] for p in profiles} - assert names == {"Default", "HighRecall"} + assert result.created == 0 + # sync_prompts ran against the pre-existing target tool, not a new one. + assert len(tgt.sync_calls) == 1 + tool_id, blob, create_copy = tgt.sync_calls[0] + assert tool_id == "tgt-existing" + assert blob["tool_metadata"]["tool_name"] == "Invoice Extractor" + assert create_copy is False + # Import path never fired on adopt. 
+ assert tgt.import_calls == [] + # Registry still republished against the adopted tool. + assert tgt.export_tool_calls == ["tgt-existing"] + assert ctx.remap.resolve("custom_tool", "src-tool-x") == "tgt-existing" + + +def test_abort_on_name_conflict_raises(): + src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "Conflict") + tgt.tools["tgt-existing"] = {"tool_name": "Conflict"} - # Prompt was already there → adopted, not duplicated. - assert len(tgt.prompts_by_tool[tgt_tool_id]) == 1 + remap = RemapTable() + _seed_adapter_remap(remap) + ctx = _ctx(src, tgt, remap=remap, on_name_conflict="abort") - assert ctx.remap.resolve("profile_manager", "src-profile-1") == "tgt-pre-profile" - assert ctx.remap.resolve("prompt", "src-prompt-1") == "tgt-pre-prompt" + with pytest.raises(NameConflictError): + CustomToolPhase(ctx).run(MigrationReport()) + assert tgt.sync_calls == [] + assert tgt.import_calls == [] -def test_dry_run_creates_nothing(): + +def test_dry_run_makes_no_writes(): src = FakeClient() tgt = FakeClient() - _preload_source(src, "src-tool-x") + _preload_source_tool(src, "src-tool-x", "T") remap = RemapTable() - _preload_remap_with_adapters(remap) + _seed_adapter_remap(remap) ctx = _ctx(src, tgt, remap=remap, dry_run=True) result = CustomToolPhase(ctx).run(MigrationReport()) assert result.skipped == 1 - assert tgt.tools == {} - assert tgt.profiles_by_tool == {} - assert tgt.export_calls == [] + assert tgt.import_calls == [] + assert tgt.sync_calls == [] + assert tgt.export_tool_calls == [] + + +def test_missing_adapter_remap_fails_tool_cleanly(): + src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "T") + # Only seed 3 of 4 adapters → x2text remap missing. 
+ remap = RemapTable() + remap.record("adapter", SRC_LLM, TGT_LLM) + remap.record("adapter", SRC_EMB, TGT_EMB) + remap.record("adapter", SRC_VEC, TGT_VEC) + ctx = _ctx(src, tgt, remap=remap) + + result = CustomToolPhase(ctx).run(MigrationReport()) + + assert result.failed == 1 + assert tgt.import_calls == [] + # Registry republish should NOT fire when the tool fails. + assert tgt.export_tool_calls == [] + # No custom_tool remap recorded. + assert ctx.remap.resolve("custom_tool", "src-tool-x") is None From 60a8f29fc45cb55a30ac57b886be750165e302f8 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Sun, 24 May 2026 04:00:41 +0530 Subject: [PATCH 10/25] fix(migration): resolve profile adapters by NAME, not UUID MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Live ProfileManagerSerializer renders adapter FKs as flat NAME strings (e.g. "gpt4-o"), not UUIDs. The first project-transfer smoke flagged "no adapter remap for gpt4-o" for every tool — the phase was looking up a name in the UUID-keyed adapter remap. Switch resolution to target-side lookup: read the adapter NAME from the source default profile and ask target's list_adapters(name=...) for the target UUID. AdapterPhase preserves names across orgs, so the lookup hits whenever adapters migrated cleanly. Smoke-test result against local stack: 5/6 source tools created on fresh run, all 5 adopted on re-run (sync_prompts), 1 source tool failed cleanly because it has no profile (test data, not a code bug). 
Co-Authored-By: Claude Opus 4.7 --- src/unstract/migration/phases/custom_tool.py | 41 +++--- tests/migration/test_custom_tool_phase.py | 131 ++++++++++--------- 2 files changed, 96 insertions(+), 76 deletions(-) diff --git a/src/unstract/migration/phases/custom_tool.py b/src/unstract/migration/phases/custom_tool.py index 58873d6..5831b52 100644 --- a/src/unstract/migration/phases/custom_tool.py +++ b/src/unstract/migration/phases/custom_tool.py @@ -44,14 +44,15 @@ ) -def _extract_adapter_id(value: Any) -> str | None: - """Profile FKs come back as nested dicts via serializer expansion; - pull the UUID back out for either flat-string or nested-dict shapes. +def _extract_adapter_name(value: Any) -> str | None: + """ProfileManagerSerializer.to_representation renders adapter FKs as + flat strings holding the adapter NAME (not the UUID). Tolerate the + nested-dict shape too in case serializer behavior diverges. """ - if isinstance(value, dict): - return value.get("id") if isinstance(value, str): - return value + return value or None + if isinstance(value, dict): + return value.get("adapter_name") or value.get("name") or value.get("id") return None @@ -218,10 +219,13 @@ def _create_fresh( def _resolve_target_adapter_ids( self, src_tool_id: str, tool_name: str ) -> dict[str, str] | None: - """Read source default profile → remap each adapter UUID to target. + """Source profile carries adapter NAMES (per serializer); resolve + each name to a target adapter UUID via ``list_adapters(name=...)``. Returns ``None`` if any of the four required adapters can't be - resolved via the ``adapter`` remap — caller fails the tool. + found on target — caller fails the tool. AdapterPhase preserves + names across orgs so this lookup should always hit when the + adapter migration ran cleanly. 
""" try: src_profiles = self.ctx.source.list_profiles(src_tool_id) @@ -244,19 +248,26 @@ def _resolve_target_adapter_ids( resolved: dict[str, str] = {} for src_field, form_field in _PROFILE_ADAPTER_FIELDS: - src_adapter_id = _extract_adapter_id(default.get(src_field)) - if not src_adapter_id: + adapter_name = _extract_adapter_name(default.get(src_field)) + if not adapter_name: logger.warning( "source default profile for tool '%s' missing adapter '%s'", tool_name, src_field, ) return None - tgt_adapter_id = self.ctx.remap.resolve("adapter", src_adapter_id) - if not tgt_adapter_id: + try: + matches = self.ctx.target.list_adapters(name=adapter_name) + except Exception as e: + logger.exception( + "list_adapters lookup failed for %s on tool '%s': %s", + adapter_name, tool_name, e, + ) + return None + if not matches: logger.warning( - "no adapter remap for %s (field %s) on tool '%s'", - src_adapter_id, src_field, tool_name, + "no target adapter named '%s' for field %s on tool '%s'", + adapter_name, src_field, tool_name, ) return None - resolved[form_field] = tgt_adapter_id + resolved[form_field] = matches[0]["id"] return resolved diff --git a/tests/migration/test_custom_tool_phase.py b/tests/migration/test_custom_tool_phase.py index b9ddfef..ec0e8cc 100644 --- a/tests/migration/test_custom_tool_phase.py +++ b/tests/migration/test_custom_tool_phase.py @@ -2,14 +2,14 @@ Coverage: - fresh path: ``export_project`` on source → ``import_project`` on - target with adapter ids resolved from source's default profile and - remapped via the adapter table. + target with adapter ids resolved by looking up each source-profile + adapter NAME against the target via ``list_adapters(name=...)``. - adopt path: existing target tool with matching name → ``sync_prompts`` overwrites prompts; no profile/adapter writes. - registry remap recorded after ``export_custom_tool``. - dry-run: no writes on either side. - abort on name conflict when option is set. 
-- missing adapter remap fails the tool cleanly. +- missing target adapter fails the tool cleanly. """ from __future__ import annotations @@ -26,14 +26,18 @@ from unstract.migration.report import MigrationReport -SRC_LLM = "11111111-1111-1111-1111-111111111111" -SRC_EMB = "22222222-2222-2222-2222-222222222222" -SRC_VEC = "33333333-3333-3333-3333-333333333333" -SRC_X2T = "44444444-4444-4444-4444-444444444444" -TGT_LLM = "a1111111-1111-1111-1111-111111111111" -TGT_EMB = "a2222222-2222-2222-2222-222222222222" -TGT_VEC = "a3333333-3333-3333-3333-333333333333" -TGT_X2T = "a4444444-4444-4444-4444-444444444444" +ADAPTER_NAMES = { + "llm": "gpt4", + "embedding_model": "ada-embed", + "vector_store": "pgvector", + "x2text": "llmw", +} +TGT_ADAPTER_IDS = { + "gpt4": "a1111111-1111-1111-1111-111111111111", + "ada-embed": "a2222222-2222-2222-2222-222222222222", + "pgvector": "a3333333-3333-3333-3333-333333333333", + "llmw": "a4444444-4444-4444-4444-444444444444", +} SRC_REG = "55555555-5555-5555-5555-555555555555" @@ -45,6 +49,7 @@ def __init__(self) -> None: self.profiles_by_tool: dict[str, list[dict]] = {} self.export_blobs: dict[str, dict] = {} self.registries_by_tool: dict[str, dict] = {} + self.adapters_by_name: dict[str, dict] = {} # Call recorders. 
self.import_calls: list[tuple[dict, dict | None]] = [] self.sync_calls: list[tuple[str, dict, bool]] = [] @@ -69,6 +74,17 @@ def list_profiles(self, tool_id: str) -> list[dict]: def export_project(self, tool_id: str) -> dict: return self.export_blobs[tool_id] + def list_adapters( + self, + *, + name: str | None = None, + adapter_type: str | None = None, + ) -> list[dict]: + if name is None: + return list(self.adapters_by_name.values()) + ad = self.adapters_by_name.get(name) + return [ad] if ad else [] + def list_registries(self, *, custom_tool: str | None = None) -> list[dict]: if custom_tool is None: return list(self.registries_by_tool.values()) @@ -116,37 +132,37 @@ def _ctx(source, target, *, remap=None, **opt_overrides) -> MigrationContext: ) -def _seed_adapter_remap(remap: RemapTable) -> None: - remap.record("adapter", SRC_LLM, TGT_LLM) - remap.record("adapter", SRC_EMB, TGT_EMB) - remap.record("adapter", SRC_VEC, TGT_VEC) - remap.record("adapter", SRC_X2T, TGT_X2T) - +def _seed_target_adapters(target: FakeClient) -> None: + """ProfileManagerSerializer surfaces adapter NAMES — target must + expose name → id lookups for the phase to resolve them. + """ + for name, adapter_id in TGT_ADAPTER_IDS.items(): + target.adapters_by_name[name] = {"id": adapter_id, "adapter_name": name} -def _src_default_profile(*, nested: bool = True) -> dict: - """Mimic ProfileManager serializer output. - ``nested=True`` matches ``to_representation`` expanding FK adapters - into nested dicts; ``nested=False`` covers the raw-UUID fallback. +def _src_default_profile(*, nested: bool = False) -> dict: + """Mirror the live ProfileManager serializer: adapter FKs render as + flat NAME strings. ``nested=True`` covers the alternate dict shape + in case backend behavior changes. 
""" if nested: return { "profile_id": "src-profile-1", "profile_name": "Default", "is_default": True, - "llm": {"id": SRC_LLM, "adapter_name": "L"}, - "embedding_model": {"id": SRC_EMB, "adapter_name": "E"}, - "vector_store": {"id": SRC_VEC, "adapter_name": "V"}, - "x2text": {"id": SRC_X2T, "adapter_name": "X"}, + "llm": {"adapter_name": ADAPTER_NAMES["llm"]}, + "embedding_model": {"adapter_name": ADAPTER_NAMES["embedding_model"]}, + "vector_store": {"adapter_name": ADAPTER_NAMES["vector_store"]}, + "x2text": {"adapter_name": ADAPTER_NAMES["x2text"]}, } return { "profile_id": "src-profile-1", "profile_name": "Default", "is_default": True, - "llm": SRC_LLM, - "embedding_model": SRC_EMB, - "vector_store": SRC_VEC, - "x2text": SRC_X2T, + "llm": ADAPTER_NAMES["llm"], + "embedding_model": ADAPTER_NAMES["embedding_model"], + "vector_store": ADAPTER_NAMES["vector_store"], + "x2text": ADAPTER_NAMES["x2text"], } @@ -170,7 +186,7 @@ def _src_export_blob(tool_name: str) -> dict: def _preload_source_tool( - client: FakeClient, tool_id: str, tool_name: str, *, nested_profile: bool = True + client: FakeClient, tool_id: str, tool_name: str, *, nested_profile: bool = False ) -> None: client.tools[tool_id] = {"tool_name": tool_name} client.profiles_by_tool[tool_id] = [_src_default_profile(nested=nested_profile)] @@ -181,27 +197,26 @@ def _preload_source_tool( } -def test_fresh_imports_with_remapped_adapter_ids_and_records_registry(): +def test_fresh_imports_with_name_resolved_adapter_ids_and_records_registry(): src = FakeClient() tgt = FakeClient() _preload_source_tool(src, "src-tool-x", "Invoice Extractor") - remap = RemapTable() - _seed_adapter_remap(remap) - ctx = _ctx(src, tgt, remap=remap) + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt) result = CustomToolPhase(ctx).run(MigrationReport()) assert result.created == 1 assert result.failed == 0 - # Exactly one import_project call with the right export blob + remapped adapter ids. 
+ # Exactly one import_project call with the right export blob + name-resolved adapter ids. assert len(tgt.import_calls) == 1 blob, adapter_ids = tgt.import_calls[0] assert blob["tool_metadata"]["tool_name"] == "Invoice Extractor" assert adapter_ids == { - "llm_adapter_id": TGT_LLM, - "vector_db_adapter_id": TGT_VEC, - "embedding_adapter_id": TGT_EMB, - "x2text_adapter_id": TGT_X2T, + "llm_adapter_id": TGT_ADAPTER_IDS["gpt4"], + "vector_db_adapter_id": TGT_ADAPTER_IDS["pgvector"], + "embedding_adapter_id": TGT_ADAPTER_IDS["ada-embed"], + "x2text_adapter_id": TGT_ADAPTER_IDS["llmw"], } # No sync_prompts on fresh path. assert tgt.sync_calls == [] @@ -215,18 +230,17 @@ def test_fresh_imports_with_remapped_adapter_ids_and_records_registry(): assert ctx.remap.resolve("prompt_studio_registry", SRC_REG) == tgt_reg_id -def test_flat_uuid_profile_also_resolves_adapter_ids(): +def test_nested_adapter_dict_also_resolves(): src = FakeClient() tgt = FakeClient() - _preload_source_tool(src, "src-tool-x", "T", nested_profile=False) - remap = RemapTable() - _seed_adapter_remap(remap) - ctx = _ctx(src, tgt, remap=remap) + _preload_source_tool(src, "src-tool-x", "T", nested_profile=True) + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt) CustomToolPhase(ctx).run(MigrationReport()) _, adapter_ids = tgt.import_calls[0] - assert adapter_ids["llm_adapter_id"] == TGT_LLM + assert adapter_ids["llm_adapter_id"] == TGT_ADAPTER_IDS["gpt4"] def test_adopt_path_calls_sync_prompts_and_skips_import(): @@ -236,9 +250,8 @@ def test_adopt_path_calls_sync_prompts_and_skips_import(): # Target already has the tool with the same name. 
tgt.tools["tgt-existing"] = {"tool_name": "Invoice Extractor"} - remap = RemapTable() - _seed_adapter_remap(remap) - ctx = _ctx(src, tgt, remap=remap) + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt) result = CustomToolPhase(ctx).run(MigrationReport()) @@ -263,9 +276,8 @@ def test_abort_on_name_conflict_raises(): _preload_source_tool(src, "src-tool-x", "Conflict") tgt.tools["tgt-existing"] = {"tool_name": "Conflict"} - remap = RemapTable() - _seed_adapter_remap(remap) - ctx = _ctx(src, tgt, remap=remap, on_name_conflict="abort") + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt, on_name_conflict="abort") with pytest.raises(NameConflictError): CustomToolPhase(ctx).run(MigrationReport()) @@ -278,9 +290,8 @@ def test_dry_run_makes_no_writes(): src = FakeClient() tgt = FakeClient() _preload_source_tool(src, "src-tool-x", "T") - remap = RemapTable() - _seed_adapter_remap(remap) - ctx = _ctx(src, tgt, remap=remap, dry_run=True) + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt, dry_run=True) result = CustomToolPhase(ctx).run(MigrationReport()) @@ -290,16 +301,14 @@ def test_dry_run_makes_no_writes(): assert tgt.export_tool_calls == [] -def test_missing_adapter_remap_fails_tool_cleanly(): +def test_missing_target_adapter_fails_tool_cleanly(): src = FakeClient() tgt = FakeClient() _preload_source_tool(src, "src-tool-x", "T") - # Only seed 3 of 4 adapters → x2text remap missing. - remap = RemapTable() - remap.record("adapter", SRC_LLM, TGT_LLM) - remap.record("adapter", SRC_EMB, TGT_EMB) - remap.record("adapter", SRC_VEC, TGT_VEC) - ctx = _ctx(src, tgt, remap=remap) + # Only seed 3 of 4 adapters → x2text lookup misses on target. 
+ for name in ("gpt4", "ada-embed", "pgvector"): + tgt.adapters_by_name[name] = {"id": TGT_ADAPTER_IDS[name], "adapter_name": name} + ctx = _ctx(src, tgt) result = CustomToolPhase(ctx).run(MigrationReport()) From 0d237a2e5b2ed71380dfeb0962a79da1c3a50c73 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Sun, 24 May 2026 04:13:46 +0530 Subject: [PATCH 11/25] feat(migration): Pipeline + APIDeployment phases MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes the last gap in the org-to-org data migration: ETL/TASK pipelines and API deployments. Both phases mirror the WorkflowPhase shape: get-or-create by name on target, FK rewrites via the workflow remap table, dry-run + abort support, idempotent re-run. Phase order extended to: ... workflow_endpoint -> pipeline -> api_deployment `api_deployment` requires endpoints to be configured before the serializer accepts it, so it must run after WorkflowEndpointPhase. PipelinePhase scope: ETL + TASK only. DEFAULT is dead v1 code; APP is a Streamlit-style deployment that doesn't fit the pipeline model. API key handling: backend auto-provisions one active key per pipeline/deployment on POST. Extra rotated source keys are NOT mirrored — UUIDs are server-generated (not settable) and operators should rotate post-migration anyway. Both phases log a WARNING when the source had more than one active key so the operator notices. Client surface added: - list_pipelines, get_pipeline, create_pipeline, update_pipeline - list_api_deployments, get_api_deployment, create_api_deployment, update_api_deployment - list_pipeline_keys, list_api_deployment_keys, create_api_key 13 new unit tests; full suite green (43 -> 56). 
Co-Authored-By: Claude Opus 4.7 --- src/unstract/migration/client.py | 84 +++++++ src/unstract/migration/orchestrator.py | 8 +- src/unstract/migration/phases/__init__.py | 4 + .../migration/phases/api_deployment.py | 140 +++++++++++ src/unstract/migration/phases/pipeline.py | 140 +++++++++++ tests/migration/test_api_deployment_phase.py | 182 ++++++++++++++ tests/migration/test_pipeline_phase.py | 222 ++++++++++++++++++ 7 files changed, 779 insertions(+), 1 deletion(-) create mode 100644 src/unstract/migration/phases/api_deployment.py create mode 100644 src/unstract/migration/phases/pipeline.py create mode 100644 tests/migration/test_api_deployment_phase.py create mode 100644 tests/migration/test_pipeline_phase.py diff --git a/src/unstract/migration/client.py b/src/unstract/migration/client.py index 99187e0..183cefd 100644 --- a/src/unstract/migration/client.py +++ b/src/unstract/migration/client.py @@ -353,3 +353,87 @@ def update_workflow_endpoint( return self._request( "PATCH", f"workflow/endpoint/{endpoint_id}/", json=payload ) + + # ----- pipelines (ETL / TASK) ----- + + def list_pipelines( + self, + *, + name: str | None = None, + pipeline_type: str | None = None, + ) -> list[dict[str, Any]]: + """List pipelines in this org, optionally filtered by exact name + and/or pipeline_type (``ETL`` / ``TASK`` / ``APP``). + """ + params: dict[str, Any] = {} + if name is not None: + params["pipeline_name"] = name + if pipeline_type is not None: + params["type"] = pipeline_type + result = self._request("GET", "pipeline/", params=params) + return result if isinstance(result, list) else result.get("results", []) + + def get_pipeline(self, pipeline_id: str) -> dict[str, Any]: + return self._request("GET", f"pipeline/{pipeline_id}/") + + def create_pipeline(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create a pipeline. Backend force-sets ``active=True`` and auto-creates + a single active API key on the new pipeline. 
+ """ + return self._request("POST", "pipeline/", json=payload) + + def update_pipeline( + self, pipeline_id: str, payload: dict[str, Any] + ) -> dict[str, Any]: + return self._request("PATCH", f"pipeline/{pipeline_id}/", json=payload) + + # ----- API deployments ----- + + def list_api_deployments( + self, + *, + api_name: str | None = None, + ) -> list[dict[str, Any]]: + """List API deployments in this org, optionally filtered by exact api_name.""" + params: dict[str, Any] = {} + if api_name is not None: + params["api_name"] = api_name + result = self._request("GET", "api/deployment/", params=params) + return result if isinstance(result, list) else result.get("results", []) + + def get_api_deployment(self, deployment_id: str) -> dict[str, Any]: + return self._request("GET", f"api/deployment/{deployment_id}/") + + def create_api_deployment(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create an API deployment. Backend auto-creates a single active key + and returns it in the response under ``api_key``. 
+ """ + return self._request("POST", "api/deployment/", json=payload) + + def update_api_deployment( + self, deployment_id: str, payload: dict[str, Any] + ) -> dict[str, Any]: + return self._request( + "PATCH", f"api/deployment/{deployment_id}/", json=payload + ) + + # ----- API keys (per pipeline / deployment) ----- + + def list_pipeline_keys(self, pipeline_id: str) -> list[dict[str, Any]]: + """List API keys belonging to a pipeline.""" + result = self._request("GET", f"api/keys/pipeline/{pipeline_id}/") + return result if isinstance(result, list) else result.get("results", []) + + def list_api_deployment_keys(self, deployment_id: str) -> list[dict[str, Any]]: + """List API keys belonging to an API deployment.""" + result = self._request("GET", f"api/keys/api/{deployment_id}/") + return result if isinstance(result, list) else result.get("results", []) + + def create_api_key(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create an extra API key tied to a pipeline or deployment. + + PipelinePhase / APIDeploymentPhase never call this: extra source + keys are only warned about. The ``api_key`` UUID is server-generated + and cannot be carried over from source. + """ + return self._request("POST", "api/keys/api/", json=payload) diff --git a/src/unstract/migration/orchestrator.py b/src/unstract/migration/orchestrator.py index 8a58ff2..5701d70 100644 --- a/src/unstract/migration/orchestrator.py +++ b/src/unstract/migration/orchestrator.py @@ -18,8 +18,10 @@ from unstract.migration.exceptions import MigrationError from unstract.migration.phases import ( AdapterPhase, + APIDeploymentPhase, ConnectorPhase, CustomToolPhase, + PipelinePhase, TagPhase, ToolInstancePhase, WorkflowEndpointPhase, @@ -33,7 +35,9 @@ # Strict dependency order. Each entry: (phase_name, phase_class). # Adapter, connector, tag are independent leaf phases. Downstream phases # (custom_tool, workflow, tool_instance, workflow_endpoint) land later -# and consume the remap entries these produce.
+# and consume the remap entries these produce. Pipeline + api_deployment +# come last: both FK the workflow and api_deployment additionally +# requires endpoints to be configured before the serializer accepts it. PHASES: list[tuple[str, type[Phase]]] = [ ("adapter", AdapterPhase), ("connector", ConnectorPhase), @@ -42,6 +46,8 @@ ("workflow", WorkflowPhase), ("tool_instance", ToolInstancePhase), ("workflow_endpoint", WorkflowEndpointPhase), + ("pipeline", PipelinePhase), + ("api_deployment", APIDeploymentPhase), ] diff --git a/src/unstract/migration/phases/__init__.py b/src/unstract/migration/phases/__init__.py index 125b19b..bde8030 100644 --- a/src/unstract/migration/phases/__init__.py +++ b/src/unstract/migration/phases/__init__.py @@ -8,19 +8,23 @@ """ from unstract.migration.phases.adapter import AdapterPhase +from unstract.migration.phases.api_deployment import APIDeploymentPhase from unstract.migration.phases.base import Phase from unstract.migration.phases.connector import ConnectorPhase from unstract.migration.phases.custom_tool import CustomToolPhase +from unstract.migration.phases.pipeline import PipelinePhase from unstract.migration.phases.tag import TagPhase from unstract.migration.phases.tool_instance import ToolInstancePhase from unstract.migration.phases.workflow import WorkflowPhase from unstract.migration.phases.workflow_endpoint import WorkflowEndpointPhase __all__ = [ + "APIDeploymentPhase", "AdapterPhase", "ConnectorPhase", "CustomToolPhase", "Phase", + "PipelinePhase", "TagPhase", "ToolInstancePhase", "WorkflowEndpointPhase", diff --git a/src/unstract/migration/phases/api_deployment.py b/src/unstract/migration/phases/api_deployment.py new file mode 100644 index 0000000..e88beff --- /dev/null +++ b/src/unstract/migration/phases/api_deployment.py @@ -0,0 +1,140 @@ +"""Migrate API deployments from source org to target org. + +APIDeployment FKs ``workflow`` — remap via the WorkflowPhase table. 
+Backend enforces one active deployment per workflow and one +``api_name`` per organization, so adopt-by-name is the only safe +re-run strategy. + +On create the backend auto-provisions a single active API key and +returns it on the response. Extra rotated keys on the source are NOT +mirrored (server-generated UUIDs can't be preserved; rotate +post-migration). +""" + +from __future__ import annotations + +import logging +from typing import Any + +from unstract.migration.exceptions import NameConflictError +from unstract.migration.phases.base import Phase, build_post_payload +from unstract.migration.report import MigrationReport, PhaseResult +from unstract.migration.walker import remap_uuids + +logger = logging.getLogger(__name__) + +API_DEPLOYMENT_PATH = "api/deployment/" + + +class APIDeploymentPhase(Phase): + name = "api_deployment" + + def run(self, report: MigrationReport) -> PhaseResult: + result = report.get_phase(self.name) + try: + self._writable = self.ctx.target.get_post_schema(API_DEPLOYMENT_PATH) + except Exception as e: + logger.exception( + "Failed to fetch target POST schema for api_deployment: %s", e + ) + result.failed += 1 + result.errors.append(f"OPTIONS api_deployment: {e}") + return result + + try: + src_deployments = self.ctx.source.list_api_deployments() + except Exception as e: + logger.exception("Failed to list source api_deployments: %s", e) + result.failed += 1 + result.errors.append(f"list source api_deployments: {e}") + return result + + logger.info("Found %d source API deployment(s)", len(src_deployments)) + for src in src_deployments: + self._migrate_one(src, result) + return result + + def _migrate_one(self, src: dict[str, Any], result: PhaseResult) -> None: + api_name = src["api_name"] + src_id = src["id"] + src_wf_id = src.get("workflow") or src.get("workflow_id") + + if not src_wf_id: + logger.warning( + "source api_deployment '%s' has no workflow FK — skipping", api_name + ) + result.skipped += 1 + return + + tgt_wf_id = 
self.ctx.remap.resolve("workflow", src_wf_id) + if not tgt_wf_id: + logger.warning( + "no workflow remap for api_deployment '%s' (src workflow %s) — skipping", + api_name, src_wf_id, + ) + result.skipped += 1 + return + + try: + existing = self.ctx.target.list_api_deployments(api_name=api_name) + except Exception as e: + logger.exception( + "Failed to GET api_deployment %s on target: %s", api_name, e + ) + result.failed += 1 + result.errors.append(f"GET {api_name}: {e}") + return + + if existing: + tgt = existing[0] + if self.ctx.options.on_name_conflict == "abort": + raise NameConflictError( + f"api_deployment '{api_name}' already exists in target as {tgt['id']}" + ) + result.adopted += 1 + logger.info( + "adopted api_deployment '%s' src=%s -> tgt=%s", + api_name, src_id, tgt["id"], + ) + elif self.ctx.options.dry_run: + result.skipped += 1 + logger.info( + "[dry-run] would create api_deployment '%s' src=%s", api_name, src_id + ) + return + else: + remapped = remap_uuids(src, self.ctx.remap) + payload = build_post_payload(remapped, self._writable) + payload["workflow"] = tgt_wf_id + try: + tgt = self.ctx.target.create_api_deployment(payload) + except Exception as e: + logger.exception( + "Failed to create api_deployment %s: %s", api_name, e + ) + result.failed += 1 + result.errors.append(f"create {api_name}: {e}") + return + result.created += 1 + logger.info( + "created api_deployment '%s' src=%s -> tgt=%s", + api_name, src_id, tgt["id"], + ) + self._warn_if_extra_source_keys(src_id, api_name) + + self.ctx.remap.record("api_deployment", src_id, tgt["id"]) + + def _warn_if_extra_source_keys(self, src_deployment_id: str, name: str) -> None: + try: + keys = self.ctx.source.list_api_deployment_keys(src_deployment_id) + except Exception as e: + logger.debug("Could not list source keys for api_deployment %s: %s", name, e) + return + active = [k for k in keys if k.get("is_active")] + if len(active) > 1: + logger.warning( + "source api_deployment '%s' had %d active API 
keys; " + "target has only the auto-provisioned default — " + "re-create the rest manually if your clients depend on them", + name, len(active), + ) diff --git a/src/unstract/migration/phases/pipeline.py b/src/unstract/migration/phases/pipeline.py new file mode 100644 index 0000000..8a05487 --- /dev/null +++ b/src/unstract/migration/phases/pipeline.py @@ -0,0 +1,140 @@ +"""Migrate ETL/TASK pipelines from source org to target org. + +Pipelines FK ``workflow`` — the only entity remap needed. On create the +backend force-sets ``active=True`` and auto-provisions one active API +key per pipeline; if the source had additional rotated keys, those are +NOT mirrored (their UUIDs are server-generated and can't be preserved, +and operators rotate post-migration anyway). + +``DEFAULT`` (legacy) and ``APP`` pipeline types are skipped — DEFAULT is +dead code from the v1 era; APP is a Streamlit-style deployment whose +lifecycle isn't shaped like an ETL/TASK pipeline. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from unstract.migration.exceptions import NameConflictError +from unstract.migration.phases.base import Phase, build_post_payload +from unstract.migration.report import MigrationReport, PhaseResult +from unstract.migration.walker import remap_uuids + +logger = logging.getLogger(__name__) + +PIPELINE_PATH = "pipeline/" +_MIGRATABLE_TYPES: frozenset[str] = frozenset({"ETL", "TASK"}) + + +class PipelinePhase(Phase): + name = "pipeline" + + def run(self, report: MigrationReport) -> PhaseResult: + result = report.get_phase(self.name) + try: + self._writable = self.ctx.target.get_post_schema(PIPELINE_PATH) + except Exception as e: + logger.exception("Failed to fetch target POST schema for pipeline: %s", e) + result.failed += 1 + result.errors.append(f"OPTIONS pipeline: {e}") + return result + + try: + src_pipelines = self.ctx.source.list_pipelines() + except Exception as e: + logger.exception("Failed to list source pipelines: %s", e) + 
result.failed += 1 + result.errors.append(f"list source pipelines: {e}") + return result + + migratable = [ + p for p in src_pipelines if p.get("pipeline_type") in _MIGRATABLE_TYPES + ] + skipped_types = len(src_pipelines) - len(migratable) + if skipped_types: + logger.info( + "Found %d source pipeline(s); skipping %d of unsupported type (DEFAULT/APP)", + len(src_pipelines), skipped_types, + ) + else: + logger.info("Found %d source pipeline(s)", len(src_pipelines)) + + for src in migratable: + self._migrate_one(src, result) + return result + + def _migrate_one(self, src: dict[str, Any], result: PhaseResult) -> None: + name = src["pipeline_name"] + src_id = src["id"] + src_wf_id = src.get("workflow") or src.get("workflow_id") + + if not src_wf_id: + logger.warning("source pipeline '%s' has no workflow FK — skipping", name) + result.skipped += 1 + return + + tgt_wf_id = self.ctx.remap.resolve("workflow", src_wf_id) + if not tgt_wf_id: + logger.warning( + "no workflow remap for pipeline '%s' (src workflow %s) — skipping", + name, src_wf_id, + ) + result.skipped += 1 + return + + try: + existing = self.ctx.target.list_pipelines(name=name) + except Exception as e: + logger.exception("Failed to GET pipeline %s on target: %s", name, e) + result.failed += 1 + result.errors.append(f"GET {name}: {e}") + return + + if existing: + tgt = existing[0] + if self.ctx.options.on_name_conflict == "abort": + raise NameConflictError( + f"pipeline '{name}' already exists in target as {tgt['id']}" + ) + result.adopted += 1 + logger.info( + "adopted pipeline '%s' src=%s -> tgt=%s", name, src_id, tgt["id"] + ) + elif self.ctx.options.dry_run: + result.skipped += 1 + logger.info("[dry-run] would create pipeline '%s' src=%s", name, src_id) + return + else: + remapped = remap_uuids(src, self.ctx.remap) + payload = build_post_payload(remapped, self._writable) + payload["workflow"] = tgt_wf_id + try: + tgt = self.ctx.target.create_pipeline(payload) + except Exception as e: + 
logger.exception("Failed to create pipeline %s: %s", name, e) + result.failed += 1 + result.errors.append(f"create {name}: {e}") + return + result.created += 1 + logger.info( + "created pipeline '%s' src=%s -> tgt=%s", name, src_id, tgt["id"] + ) + self._warn_if_extra_source_keys(src_id, name) + + self.ctx.remap.record("pipeline", src_id, tgt["id"]) + + def _warn_if_extra_source_keys(self, src_pipeline_id: str, name: str) -> None: + try: + keys = self.ctx.source.list_pipeline_keys(src_pipeline_id) + except Exception as e: + logger.debug("Could not list source keys for pipeline %s: %s", name, e) + return + active = [k for k in keys if k.get("is_active")] + if len(active) > 1: + logger.warning( + "source pipeline '%s' had %d active API keys; " + "target has only the auto-provisioned default — " + "re-create the rest manually if your clients depend on them", + name, len(active), + ) diff --git a/tests/migration/test_api_deployment_phase.py b/tests/migration/test_api_deployment_phase.py new file mode 100644 index 0000000..a80900c --- /dev/null +++ b/tests/migration/test_api_deployment_phase.py @@ -0,0 +1,182 @@ +"""Tests for ``APIDeploymentPhase``. + +Coverage: +- happy path: source api_deployments created with workflow FK remapped. +- adopt by api_name on existing target deployment. +- skipped when workflow remap missing. +- dry-run is a no-op. +- abort raises ``NameConflictError``. +- extra source keys produce a warning, never a failure. 
+""" + +from __future__ import annotations + +import logging + +import pytest + +from unstract.migration.context import ( + MigrationContext, + MigrationOptions, + RemapTable, +) +from unstract.migration.exceptions import NameConflictError +from unstract.migration.phases.api_deployment import APIDeploymentPhase +from unstract.migration.report import MigrationReport + + +API_DEPLOYMENT_POST_SCHEMA = frozenset( + { + "display_name", + "description", + "workflow", + "is_active", + "api_name", + "shared_users", + "shared_to_org", + } +) + + +class FakeClient: + def __init__(self, deployments: list[dict] | None = None): + self.deployments: list[dict] = list(deployments or []) + self.posts: list[dict] = [] + self.keys_by_deployment: dict[str, list[dict]] = {} + self._next = 1 + + def get_post_schema(self, entity_path: str) -> frozenset[str]: + return API_DEPLOYMENT_POST_SCHEMA + + def list_api_deployments(self, *, api_name: str | None = None): + result = self.deployments + if api_name is not None: + result = [d for d in result if d["api_name"] == api_name] + return list(result) + + def create_api_deployment(self, payload: dict) -> dict: + new = dict(payload) + new["id"] = f"tgt-dep-{self._next:04d}" + new["api_key"] = f"key-{self._next:04d}" + self._next += 1 + self.deployments.append(new) + self.posts.append(new) + return new + + def list_api_deployment_keys(self, deployment_id: str) -> list[dict]: + return list(self.keys_by_deployment.get(deployment_id, [])) + + +def _src_deployment( + id_: str, api_name: str, workflow_id: str, *, display_name: str | None = None +) -> dict: + return { + "id": id_, + "api_name": api_name, + "display_name": display_name or api_name, + "description": f"{api_name} desc", + "workflow": workflow_id, + "workflow_id": workflow_id, + "is_active": True, + "shared_users": [], + "shared_to_org": False, + } + + +def _ctx(source, target, *, remap=None, **opt_overrides): + return MigrationContext( + source=source, + target=target, + 
options=MigrationOptions(**opt_overrides), + remap=remap or RemapTable(), + ) + + +def test_happy_path_creates_deployment_with_remapped_workflow(): + src = FakeClient([_src_deployment("src-dep-1", "invoices_api", "wf-src-1")]) + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + result = APIDeploymentPhase(ctx).run(MigrationReport()) + + assert result.created == 1 + assert result.failed == 0 + posted = tgt.posts[0] + assert posted["api_name"] == "invoices_api" + assert posted["workflow"] == "wf-tgt-1" + assert ctx.remap.resolve("api_deployment", "src-dep-1") == posted["id"] + + +def test_adopts_existing_deployment_by_api_name(): + src = FakeClient([_src_deployment("src-dep-1", "invoices_api", "wf-src-1")]) + tgt = FakeClient( + [{"id": "tgt-existing", "api_name": "invoices_api"}] + ) + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + result = APIDeploymentPhase(ctx).run(MigrationReport()) + + assert result.adopted == 1 + assert result.created == 0 + assert tgt.posts == [] + assert ctx.remap.resolve("api_deployment", "src-dep-1") == "tgt-existing" + + +def test_skipped_when_workflow_remap_missing(): + src = FakeClient([_src_deployment("src-dep-1", "orphan", "wf-src-1")]) + tgt = FakeClient() + ctx = _ctx(src, tgt) # No workflow remap. 
+ + result = APIDeploymentPhase(ctx).run(MigrationReport()) + + assert result.skipped == 1 + assert tgt.posts == [] + + +def test_dry_run_makes_no_writes(): + src = FakeClient([_src_deployment("src-dep-1", "invoices_api", "wf-src-1")]) + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap, dry_run=True) + + result = APIDeploymentPhase(ctx).run(MigrationReport()) + + assert result.skipped == 1 + assert tgt.posts == [] + + +def test_abort_on_name_conflict_raises(): + src = FakeClient([_src_deployment("src-dep-1", "invoices_api", "wf-src-1")]) + tgt = FakeClient([{"id": "tgt-existing", "api_name": "invoices_api"}]) + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap, on_name_conflict="abort") + + with pytest.raises(NameConflictError): + APIDeploymentPhase(ctx).run(MigrationReport()) + + +def test_extra_source_keys_log_warning_not_failure(caplog): + src = FakeClient([_src_deployment("src-dep-1", "invoices_api", "wf-src-1")]) + src.keys_by_deployment["src-dep-1"] = [ + {"id": "k1", "is_active": True}, + {"id": "k2", "is_active": True}, + ] + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + with caplog.at_level( + logging.WARNING, logger="unstract.migration.phases.api_deployment" + ): + result = APIDeploymentPhase(ctx).run(MigrationReport()) + + assert result.created == 1 + assert result.failed == 0 + assert any("2 active API keys" in r.message for r in caplog.records) diff --git a/tests/migration/test_pipeline_phase.py b/tests/migration/test_pipeline_phase.py new file mode 100644 index 0000000..a064628 --- /dev/null +++ b/tests/migration/test_pipeline_phase.py @@ -0,0 +1,222 @@ +"""Tests for ``PipelinePhase``. + +Coverage: +- happy path: source ETL/TASK pipelines created with workflow FK remapped. 
+- DEFAULT and APP types are skipped (out of migration scope). +- adopt path on name conflict. +- skipped when workflow remap missing. +- dry-run is a no-op. +- abort raises ``NameConflictError``. +- extra source keys produce a warning, never a failure. +""" + +from __future__ import annotations + +import logging + +import pytest + +from unstract.migration.context import ( + MigrationContext, + MigrationOptions, + RemapTable, +) +from unstract.migration.exceptions import NameConflictError +from unstract.migration.phases.pipeline import PipelinePhase +from unstract.migration.report import MigrationReport + + +PIPELINE_POST_SCHEMA = frozenset( + { + "pipeline_name", + "workflow", + "pipeline_type", + "cron_string", + "app_id", + "app_icon", + "app_url", + "access_control_bundle_id", + "shared_users", + "shared_to_org", + } +) + + +class FakeClient: + def __init__(self, pipelines: list[dict] | None = None): + self.pipelines: list[dict] = list(pipelines or []) + self.posts: list[dict] = [] + self.keys_by_pipeline: dict[str, list[dict]] = {} + self._next = 1 + + def get_post_schema(self, entity_path: str) -> frozenset[str]: + return PIPELINE_POST_SCHEMA + + def list_pipelines( + self, *, name: str | None = None, pipeline_type: str | None = None + ): + result = self.pipelines + if name is not None: + result = [p for p in result if p["pipeline_name"] == name] + if pipeline_type is not None: + result = [p for p in result if p.get("pipeline_type") == pipeline_type] + return list(result) + + def create_pipeline(self, payload: dict) -> dict: + new = dict(payload) + new["id"] = f"tgt-pipeline-{self._next:04d}" + self._next += 1 + self.pipelines.append(new) + self.posts.append(new) + return new + + def list_pipeline_keys(self, pipeline_id: str) -> list[dict]: + return list(self.keys_by_pipeline.get(pipeline_id, [])) + + +def _src_pipeline( + id_: str, + name: str, + workflow_id: str, + *, + pipeline_type: str = "ETL", + cron_string: str | None = None, +) -> dict: + return { + 
"id": id_, + "pipeline_name": name, + "workflow": workflow_id, + "workflow_id": workflow_id, + "workflow_name": "wf", + "pipeline_type": pipeline_type, + "active": True, + "scheduled": cron_string is not None, + "cron_string": cron_string, + "app_id": None, + "app_icon": None, + "app_url": None, + "access_control_bundle_id": None, + "shared_users": [], + "shared_to_org": False, + } + + +def _ctx(source, target, *, remap=None, **opt_overrides): + return MigrationContext( + source=source, + target=target, + options=MigrationOptions(**opt_overrides), + remap=remap or RemapTable(), + ) + + +def test_happy_path_creates_pipeline_with_remapped_workflow(): + src = FakeClient( + [_src_pipeline("src-pl-1", "Daily Invoices", "wf-src-1")] + ) + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + result = PipelinePhase(ctx).run(MigrationReport()) + + assert result.created == 1 + assert result.failed == 0 + posted = tgt.posts[0] + assert posted["pipeline_name"] == "Daily Invoices" + assert posted["workflow"] == "wf-tgt-1" + assert ctx.remap.resolve("pipeline", "src-pl-1") == posted["id"] + + +def test_default_and_app_pipeline_types_are_skipped(): + src = FakeClient( + [ + _src_pipeline("src-1", "default-legacy", "wf-src-1", pipeline_type="DEFAULT"), + _src_pipeline("src-2", "streamlit-app", "wf-src-1", pipeline_type="APP"), + _src_pipeline("src-3", "real-etl", "wf-src-1", pipeline_type="ETL"), + ] + ) + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + result = PipelinePhase(ctx).run(MigrationReport()) + + assert result.created == 1 + assert len(tgt.posts) == 1 + assert tgt.posts[0]["pipeline_name"] == "real-etl" + + +def test_adopts_existing_pipeline_by_name(): + src = FakeClient([_src_pipeline("src-pl-1", "Daily Invoices", "wf-src-1")]) + tgt = FakeClient( + [{"id": "tgt-existing", "pipeline_name": "Daily 
Invoices"}] + ) + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + result = PipelinePhase(ctx).run(MigrationReport()) + + assert result.adopted == 1 + assert result.created == 0 + assert tgt.posts == [] + assert ctx.remap.resolve("pipeline", "src-pl-1") == "tgt-existing" + + +def test_skipped_when_workflow_remap_missing(): + src = FakeClient([_src_pipeline("src-pl-1", "Orphan", "wf-src-1")]) + tgt = FakeClient() + ctx = _ctx(src, tgt) # No workflow remap. + + result = PipelinePhase(ctx).run(MigrationReport()) + + assert result.skipped == 1 + assert result.failed == 0 + assert tgt.posts == [] + + +def test_dry_run_makes_no_writes(): + src = FakeClient([_src_pipeline("src-pl-1", "Daily Invoices", "wf-src-1")]) + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap, dry_run=True) + + result = PipelinePhase(ctx).run(MigrationReport()) + + assert result.skipped == 1 + assert tgt.posts == [] + + +def test_abort_on_name_conflict_raises(): + src = FakeClient([_src_pipeline("src-pl-1", "Daily Invoices", "wf-src-1")]) + tgt = FakeClient([{"id": "tgt-existing", "pipeline_name": "Daily Invoices"}]) + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap, on_name_conflict="abort") + + with pytest.raises(NameConflictError): + PipelinePhase(ctx).run(MigrationReport()) + + +def test_extra_source_keys_log_warning_not_failure(caplog): + src = FakeClient([_src_pipeline("src-pl-1", "Daily Invoices", "wf-src-1")]) + src.keys_by_pipeline["src-pl-1"] = [ + {"id": "k1", "is_active": True}, + {"id": "k2", "is_active": True}, + {"id": "k3", "is_active": False}, + ] + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + with caplog.at_level(logging.WARNING, logger="unstract.migration.phases.pipeline"): + result = 
PipelinePhase(ctx).run(MigrationReport()) + + assert result.created == 1 + assert result.failed == 0 + assert any("2 active API keys" in r.message for r in caplog.records) From 05055ce8eae16e7e5c3302e1188f9251b2557367 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Sun, 24 May 2026 13:18:11 +0530 Subject: [PATCH 12/25] feat(migration): FilesPhase for Prompt Studio document corpus [UN-3479] MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a `files` phase that moves Prompt Studio document files between orgs using the existing Platform API endpoints — no new BE surface. - Default mode (`--file-strategy=platform_api`): lists target DM rows once per tool for idempotency, downloads each missing source file via `fetch_contents_ide`, decodes per mime, and POSTs as multipart through `upload_for_ide`. - Skip mode (`--skip-files`): metadata only; source filenames go into `MigrationReport.skipped_files` for operator-driven UI re-upload. - Mixed-mode reporting: files above `--max-file-size` (default 25MB) land in `oversize_files`; mime types the BE endpoint can't round-trip losslessly (Excel placeholder, etc.) land in `unsupported_files`. Sibling files always continue — phase never aborts on file-level issues. Transport-level failures land in `failed_files`. - Idempotency-only retries on 5xx + transient connection errors. - Wired after `custom_tool` in the orchestrator (consumes its remap). Report grows five new typed lists (`uploaded_files`, `skipped_files`, `oversize_files`, `unsupported_files`, `failed_files`) plus end-of-phase rendering in both rich and plain modes. 12 new unit tests cover happy path, idempotency skip, oversize, unsupported mime, skip strategy, dry-run, retry on 5xx, missing custom_tool remap, per-tool source-list failure isolation, upload failure capture, and parametrised text/csv + text/plain round-trips.
--- docs/internal/files-migration-plan.md | 355 ++++++++++++++++++++ src/unstract/migration/README.md | 188 +++++++++++ src/unstract/migration/cli.py | 60 +++- src/unstract/migration/client.py | 48 +++ src/unstract/migration/context.py | 7 + src/unstract/migration/orchestrator.py | 2 + src/unstract/migration/phases/__init__.py | 2 + src/unstract/migration/phases/files.py | 325 ++++++++++++++++++ src/unstract/migration/report.py | 62 ++++ tests/migration/test_files_phase.py | 391 ++++++++++++++++++++++ 10 files changed, 1439 insertions(+), 1 deletion(-) create mode 100644 docs/internal/files-migration-plan.md create mode 100644 src/unstract/migration/README.md create mode 100644 src/unstract/migration/phases/files.py create mode 100644 tests/migration/test_files_phase.py diff --git a/docs/internal/files-migration-plan.md b/docs/internal/files-migration-plan.md new file mode 100644 index 0000000..8590778 --- /dev/null +++ b/docs/internal/files-migration-plan.md @@ -0,0 +1,355 @@ +# Files-Migration Phase — Implementation Plan + +**Status:** draft, revised 2026-05-24 (post product-owner review). Not a KB doc. Authoritative runbook for the SDK implementor agent (`odm-sdk-impl`) to land the `files` phase end-to-end. + +**Constraints from product owner:** +- **No new backend Platform API endpoints.** Use what exists today. +- **No storage-backend direct copy** (per-provider CLI + credentials handoff is too painful to maintain across customer environments). +- **No signed-URL relay** (would depend on new BE endpoints). +- Backend changes limited to non-API refactors (streaming write helper) that benefit the FE simultaneously. + +**Net scope: Platform-API uploads only. Oversize files are reported, not aborted on — operator handles those via UI re-upload.** + +--- + +## 0. What this phase moves + +Prompt Studio document files only. The user-uploaded test corpus per `CustomTool`. 
Storage path on both ends: + +``` +{PERMANENT_REMOTE_STORAGE}/{REMOTE_PROMPT_STUDIO_FILE_PATH}/{org_id}/{user_id}/{tool_id}/ +``` + +Server-generated subdirs (`extract/`, `summarize/`, `converted/`) are **out of scope** — regenerated on target on first index/summarize/preview. Don't migrate them. + +`DocumentManager` rows (`prompt_studio_document_manager_v2.DocumentManager`) are the DB-side mirror of the files. **The Platform API upload endpoint creates the `DocumentManager` row as a side effect of the file upload itself** — see §10a. `CustomToolPhase` does not create them. + +## 1. Where the phase sits + +Dependency order: + +``` +… → custom_tool → files → workflow → tool_instance → workflow_endpoint → api_deployment → pipeline +``` + +Runs after `custom_tool` (which mints the `src_tool_id → tgt_tool_id` remap and creates the CustomTool + ProfileManager + prompts on target) and before `workflow` (workflows can reference indexed prompt outputs, but `workflow` creation itself doesn't need the bytes present). + +## 2. Strategy dispatch + +Selectable via `MigrationOptions.file_strategy`: + +| Value | Behavior | +|-------|----------| +| `"platform_api"` (default) | Uploads through existing Platform API endpoints, capped per file. Files over cap are reported, not aborted on. | +| `"skip"` | Operator re-uploads manually via UI on target; SDK does pure metadata migration. | + +**Sizing rationale (US prod data, 2026-05-24):** 18,890 orgs with files, p99=14 files/org, max external=529 (buildwithstudio). Moody's biggest single user-org=99 files. At ~3 sec per file roundtrip, Moody's worst case = ~5 min, buildwithstudio-class = ~25 min. Sequential through the Platform API is in budget. File sizes not in DB — sampling recommended but not blocking; cap is the safety valve. + +## 3. 
Default upload flow + +### Mechanism + +Use the existing endpoints exactly as they are: + +- **Download from source:** `GET /api/v1/{org_slug}/prompt-studio/file/{tool_id}/?file_name={file_name}` — returns `{"data": "...", "mime_type": "..."}` per `PromptStudioFileHelper.fetch_file_contents` (`data` is base64 for binary types such as PDF, plain text for `text/plain` / `text/csv`). +- **Upload to target:** `POST /api/v1/{org_slug}/prompt-studio/file/{tool_id}/` with multipart body — see `upload_for_ide` signature. +- **List target files for idempotency:** `GET /api/v1/{org_slug}/prompt-document/?tool_id={tool_id}` — returns `[{document_id, document_name, tool}, ...]`; the `document_name` field is what the pre-check matches against. + +### Memory + size discipline + +- **Hard per-file cap.** Default `max_file_size = 25 * 1024 * 1024` (25 MB). Above → skip that file, log to `MigrationReport.oversize_files`, continue siblings. Operator can override via CLI flag (`--max-file-size`). +- **Cap rationale.** 10 MB is too tight (a 50-page scanned PDF can be 15-30 MB); 50 MB risks cloud-worker memory pressure given the base64+JSON-wrap overhead on download (~3x the file size in peak memory). 25 MB covers typical Prompt Studio test docs while staying within a safe worker memory budget. +- **No reliable pre-flight size check.** Existing endpoints don't expose size separately from full payload. Workaround: enforce cap at request time after download — the SDK only knows the size after it has the bytes in memory. Acceptable given the cap is small enough that a single oversize-download spike is bounded. +- **Concurrency = 1** per phase. Do not raise — single-file-at-a-time is what keeps the cloud worker hold bounded. +- **Retry** transient HTTP errors (5xx, ConnectionError, Timeout) with exponential backoff, max 3 attempts. +- **No body logging** in either direction. + +### Idempotency caching (important for re-run cost) + +Fetch the target tool's file list **once per tool**, not per-file. With 529-file corpora (buildwithstudio scale), per-entity name lookups would balloon to 529 list calls just for skip-checks.
Pattern: + +```python +def migrate_tool_files(src_tool_id, tgt_tool_id): + tgt_filenames = set(target.list_tool_filenames(tgt_tool_id)) # 1 call + src_files = source.list_tool_filenames(src_tool_id) # 1 call + for fname in src_files: + if fname in tgt_filenames: + report.skipped_existing += 1 + continue + migrate_file(src_tool_id, tgt_tool_id, fname, tgt_filenames) +``` + +This keeps re-run cost at ~2 HTTP calls per tool regardless of file count. Moody's full re-run (10 user-orgs × ~10 tools × 2 calls) ≈ 200 quick HTTP calls ≈ 10-20 sec total for the files phase. + +### Per-file flow + +```python +def migrate_file(src_tool_id, tgt_tool_id, file_name, tgt_filenames): + # idempotency: skip if name already present on target + # (tgt_filenames is the per-tool set fetched once by the caller — no extra list call here) + if file_name in tgt_filenames: + report.skipped_existing += 1 + return + + # download — full file in memory (existing endpoint constraint) + resp = source.platform_api.get( + f"/{src_org_slug}/prompt-studio/file/{src_tool_id}/", + params={"file_name": file_name}, + ) + payload = resp.json() + mime = payload["mime_type"] + data_field = payload["data"] + + if mime == "application/pdf": + raw = base64.b64decode(data_field) + elif mime in ("text/plain", "text/csv"): + raw = data_field.encode("utf-8") + elif mime.startswith("application/vnd.ms-excel") or mime.startswith("application/vnd.openxmlformats"): + raw = base64.b64decode(data_field) # verify against helper's actual branch + else: + report.warnings.append(f"{file_name}: unknown mime '{mime}', skipping") + return + + if len(raw) > options.max_file_size: + report.oversize_files.append({ + "tool_id": tgt_tool_id, + "tool_name": tgt_tool_name, + "file_name": file_name, + "size_bytes": len(raw), + "cap_bytes": options.max_file_size, + }) + return # oversize → operator handles via UI re-upload + + # upload as multipart + target.platform_api.post( + f"/{tgt_org_slug}/prompt-studio/file/{tgt_tool_id}/", + files={"file": (file_name, raw, mime)}, + ) + report.uploaded += 1 +``` + +Verify mime-branch coverage
against the actual `fetch_file_contents` helper (lines 167-188 of `prompt_studio_file_helper.py`) and extend if new branches landed. + +### Cloud safety check + +Before phase starts, log: + +``` +Files phase about to run via Platform API: + source: target: + tools: cap: 25 MB/file + estimated cloud-worker hold: ~1 worker × +``` + +Operator can abort here if running against cloud during peak hours. + +### Idempotency + +Name-based, listed-target side. No hash check. + +## 4. Skip mode + +Pure metadata migration. Phase prints, per migrated tool: + +``` +files: skipped (--skip-files). Documents must be re-uploaded on target via the IDE. + Navigate to each migrated tool on the target deployment and re-upload via the file manager pane. + + - tool 'invoice_extractor' (tgt_id=abc-123): 12 files expected + sample.pdf, contract_q1.pdf, ... + - tool 'receipt_classifier' (tgt_id=def-456): 3 files expected + receipt1.pdf, receipt2.pdf, receipt3.pdf +``` + +**No "destination path" is needed** — the operator clicks "upload" in the target UI's file manager, picks the file from their local disk, and the backend's `upload_for_ide` constructs the storage path from the calling user's session and the target `tool_id`. The operator never touches storage paths. + +**What the operator needs in their hand to do skip-files re-uploads:** the source bytes themselves. Options: +- They already have local copies (they originally uploaded these from their own machine). +- They download from source deployment UI one file at a time before re-uploading to target. Painful at >10-20 files; this is why uploads via Platform API are the default. + +`MigrationReport.skipped_files` carries the full list `[{tool_id, tool_name, file_name, source_org_slug, source_tool_id}, ...]` so external tooling (or a future helper script) can drive a download-then-upload loop using the source deployment's credentials. + +## 4a. 
Oversize file handling + +When a file exceeds `max_file_size` during the default flow, the SDK does **not** abort. It: + +1. Logs the file under `MigrationReport.oversize_files` with `{tool_id, tool_name, file_name, size_bytes, cap_bytes}`. +2. Continues with sibling files in the same tool. +3. At end-of-phase, prints a "files requiring manual upload" section listing the oversize subset. + +A corpus where 95% of files fit under cap therefore gets 95% auto-migrated and 5% surfaced for manual UI re-upload, in one run. + +## 6. Backend buffering side-quest (separate commit, not blocking the phase) + +The product owner has approved buffering the upload side in the BE because it benefits the FE concurrent-upload path too. **Constraint:** no new endpoints, no behavior change. Just internal refactor. + +### Change + +In `backend/utils/file_storage/helpers/prompt_studio_file_helper.py:upload_for_ide`, replace: + +```python +fs_instance.write( + path=file_path, + mode="wb", + data=file_data if isinstance(file_data, bytes) else file_data.read(), +) +``` + +with chunked streaming when `file_data` is an UploadedFile: + +```python +if isinstance(file_data, bytes): + fs_instance.write(path=file_path, mode="wb", data=file_data) +else: + with fs_instance.open(file_path, mode="wb", block_size=8 * 1024 * 1024) as out: + for chunk in file_data.chunks(chunk_size=8 * 1024 * 1024): + out.write(chunk) +``` + +Apply identical change to `upload_converted_for_ide` (same shape, different path). + +### Regression risks to verify + +1. **`block_size` must be set explicitly** — fsspec defaults vary per provider (GCS, S3, MinIO, Azure). Without it the implementation may buffer to memory until much larger thresholds. +2. **Partial-write cleanup on failure.** Wrap the streaming write in `try`/`except`; on exception, call the underlying multipart-abort if available (`fs.cancel(...)` for GCS resumable; `abort_multipart_upload` for S3 — exposed via `fs.fs.abort_multipart_upload` on s3fs). 
Document if a provider doesn't support clean abort; aged-out incomplete multipart uploads cost storage money. +3. **MIME detection.** `fetch_file_contents` does `fs.mime_type(...)` which is already partial-read-capable but tests should cover edge file types (Office docs, oddly-headered PDFs). +4. **`isinstance(file_data, bytes)` branch preserved** — some callers pass raw bytes; that path stays single-shot. +5. **Test coverage.** Add streaming-specific tests under `unstract/sdk1/tests/file_storage/` mirroring existing single-shot tests. Confirm peak memory bounded via `tracemalloc` snapshot in test. +6. **No change to `fetch_contents_ide`.** Constraint says no new BE APIs and FE consumes the existing response shape (base64 in JSON). Leave it alone. + +### Out of scope for this commit + +- Download streaming — would change `fetch_contents_ide` contract; product owner ruled out. +- New raw-download endpoint — same constraint. +- Resumability on the upload side — would also require new endpoint surface. + +## 7. SDK package layout + +``` +src/unstract/migration/ + ├── phases/ + │ └── files.py # FilesPhase: upload via Platform API with oversize reporting, or skip + ├── file_transport/ + │ ├── __init__.py + │ └── platform_api.py # download (base64-decode) + upload (multipart) using existing endpoints + └── ... (existing phases unchanged) +``` + +No direct-storage-copy or signed-URL transport modules — those approaches are out of scope. + +## 8. `MigrationOptions` additions + +```python +@dataclass +class MigrationOptions: + # ... existing fields ... + file_strategy: Literal["platform_api", "skip"] = "platform_api" + max_file_size: int = 25 * 1024 * 1024 # 25 MB; oversize files are reported, not uploaded +``` + +CLI flags: + +``` +--file-strategy {platform_api,skip} # default: platform_api +--max-file-size 25MB # accepts human-readable sizes +--skip-files # alias for --file-strategy=skip +``` + +No `auto` dispatch needed — only one byte-moving strategy exists. + +## 9. 
Test plan + +| Test | Mode | What it proves | +|------|------|----------------| +| 5 MB PDF via Platform API | default | round-trip base64+multipart works against existing helpers | +| 30 MB PDF via Platform API | default | cap fires; entry appears in `oversize_files`; sibling files continue | +| Tool with 10 files (mix of small + 1 oversize) | default | 9 migrated, 1 reported for manual upload, run exits 0 | +| Re-run after success | default | name-based idempotency: target file present → skip | +| `--skip-files` | skip | pure metadata migration; report lists tools + filenames (no storage paths) | +| `MigrationReport.skipped_files` schema | skip | shape matches `[{tool_id, tool_name, file_name, source_org_slug, source_tool_id}]` | +| `MigrationReport.oversize_files` schema | default | shape matches `[{tool_id, tool_name, file_name, size_bytes, cap_bytes}]` | +| Streaming upload backend refactor: 100 MB PDF via UI | (BE) | peak RSS bounded to ~chunk_size; file is byte-identical | +| Streaming upload + network failure mid-stream | (BE) | partial multipart aborted; no orphan upload in bucket | +| Default mode concurrent with FE user uploading | default | both succeed; no worker starvation observable | +| Moody's-scale run (~99 files, all under cap) | default | completes in <10 min; no failures; report counts match input | + +## 10. Acceptance + +- Adapter, connector, tag, custom_tool, prompts, profile_managers, prompt_registry phases unchanged. +- New `files` phase wired into `migrate()` orchestrator after `custom_tool`. +- Local smoke: fresh target org, run default mode against a tool with 3 small PDFs. Re-run: 3 skips, 0 failures. +- Local smoke: same scenario with one 30 MB PDF added → 3 uploaded, 1 in `oversize_files`, exit 0. +- Local smoke: `--skip-files` against same setup → 0 uploaded, all 4 listed in `skipped_files` with tool + filename. +- BE refactor commit (separate from SDK commit) lands on `feat/org-migration-platform-api-gaps`. 
SDK commit lands on `feat/org-migration`. +- `MigrationReport` exposes: `uploaded_files`, `skipped_files`, `oversize_files`, `failed_files` (each a typed list with tool_id + tool_name + file_name minimum). + +## 10a. Idempotency model — pre-check is load-bearing for correctness + +`upload_for_ide` (`views.py:1009`) is **not idempotent**, and in fact actively errors on retries. Order of operations: + +1. `PromptStudioFileHelper.upload_for_ide(...)` — write file to storage. Overwrites on collision (storage-idempotent). +2. `PromptStudioDocumentHelper.create(tool_id, document_name)` — unconditional `DocumentManager.objects.create`. + +The `DocumentManager` model carries `UniqueConstraint(fields=["document_name", "tool"])`. So calling `upload_for_ide` twice for the same `(tool, filename)` overwrites the file in storage, then raises `IntegrityError` on the second create — propagates as 500. **The SDK MUST pre-check; otherwise re-runs error out.** + +For files specifically, the pattern is: + +```python +def migrate_tool_files(src_tool_id, tgt_tool_id): + tgt_filenames = set(target.list_dm_rows(tgt_tool_id)) + for src_fname in source.list_dm_rows(src_tool_id): + if src_fname in tgt_filenames: + continue + upload_file(src_tool_id, tgt_tool_id, src_fname) +``` + +This guarantees the SDK never invokes the non-idempotent upload twice. The duplicate-DM-row outcome (well, the 500 — see above) only materializes if **something other than the SDK** invokes the upload endpoint with a name the SDK has already migrated. + +### Important difference from earlier draft of this plan + +This plan previously claimed `CustomToolPhase` creates `DocumentManager` rows on target. **It does not.** `import_project` only creates `CustomTool` + `ProfileManager` + prompts; DM rows are only created by `upload_for_ide`. Therefore: + +- Before the files phase runs on a fresh target tool, the DM list is empty. Pre-check returns ∅. SDK uploads all files. 
✅ +- On re-run, pre-check returns the SDK's own previously-created rows. SDK skips them. ✅ +- For an oversize / skipped file the operator re-uploads via UI: no SDK DM row exists for that filename → UI creates a single row, no duplicate. ✅ + +The duplicate-DM-row problem from prior plan revisions is **not reachable** through SDK flows. It only manifests if a user uploads via UI mid-migration (between the SDK's list call and its upload call for that same filename) — mitigation is to run migrations in low-activity windows. + +### Crash semantics + +The file-first-then-DM order is fortunate: + +| Failure point | Target state | Re-run outcome | +|---------------|--------------|----------------| +| File write fails | No file, no DM row | Pre-check sees no DM → retry → clean | +| File written, DM create fails | File on disk, no DM row | Pre-check sees no DM → retry → file overwritten, DM created. **Self-healing** | +| Both succeed, SDK dies | File + DM both present | Pre-check sees DM → skip. Correct | +| Both succeed, network blip on response | File + DM both present, SDK saw error | Same. SDK reconciles via pre-check, not via its own ack | + +The opposite order (DM first, file second) would leave a "ghost DM row pointing at nothing" state that the pre-check couldn't distinguish from a real one. We get lucky here. + +### Operator UI re-upload caveat + +When the operator manually re-uploads a file via the target UI to fill in something the SDK skipped or marked oversize, the UI hits the same non-idempotent endpoint. Because the SDK never created a DM row for that filename in those skipped cases, the UI's create succeeds cleanly — no duplicate. + +The only failure pattern: operator re-uploads a file the SDK already migrated (i.e. a file already on disk and in the DM table). The unique constraint will reject this with a 500. UI surfaces an error; the operator has to delete the existing row first via UI, then re-upload. 
+ +### Optional backend cleanup (out of scope) + +`PromptStudioDocumentHelper.create` could be changed to `get_or_create(tool, document_name)` — eliminates the 500 on re-upload, benefits UI behavior too. **Not required by the SDK** (the pre-check makes the SDK safe regardless) and not blocking; queue as a follow-up if UI ergonomics complaints arrive. + +Document explicitly in the README's "What if files aren't on disk?" and "Files phase specifics" sections. + +## 11. What to push back on if encountered + +- Any request to add a new BE endpoint → confirm with product owner first; the constraint is "no new APIs". +- Any request to change `fetch_contents_ide` shape → same; FE consumes the base64+JSON envelope today. +- Any request to raise `max_file_size` above 50 MB → confirm cloud-worker RAM budget first; default 25 MB exists because base64 round-trip × 2 already runs ~3x file size in peak worker memory. +- Files larger than 25 MB common in the customer's corpus → that's expected to be a tail; operator handles via UI re-upload. If the tail is the majority, escalate — may need to revisit the no-new-endpoints constraint for a streaming download endpoint. +- Any request to copy files directly between storage backends or via signed URLs → ruled out: per-provider CLI maintenance burden, and signed URLs would need new BE endpoints. + +## 12. 
References + +- `backend/utils/file_storage/helpers/prompt_studio_file_helper.py` — `upload_for_ide`, `fetch_file_contents` (note: `fetch_contents_ide` returns base64-wrapped raw bytes for PDF, not LLMW-extracted text — extracted text lives under `extract/` subdir, written by a different code path) +- `backend/prompt_studio/prompt_studio_core_v2/urls.py` — `prompt_studio_file` route +- `backend/prompt_studio/prompt_studio_document_manager_v2/` — `DocumentManager` model +- Branch: backend changes → `feat/org-migration-platform-api-gaps`; SDK changes → `feat/org-migration` diff --git a/src/unstract/migration/README.md b/src/unstract/migration/README.md new file mode 100644 index 0000000..5d892e0 --- /dev/null +++ b/src/unstract/migration/README.md @@ -0,0 +1,188 @@ +# `unstract.migration` — Org-to-Org Data Migration + +SDK subpackage that lifts an Unstract organization's configured resources from one deployment into another using existing Platform API endpoints. Adapters, connectors, custom tools, prompts, profiles, workflows, tool instances, workflow endpoints, tags, API deployments, pipelines, and Prompt Studio document files. + +## Quickstart + +```bash +UNSTRACT_SRC_PLATFORM_KEY=src_pk_... \ +UNSTRACT_TGT_PLATFORM_KEY=tgt_pk_... \ +uv run python -m unstract.migration migrate \ + --source-url https://us.unstract.com \ + --source-org my-source-org \ + --target-url https://us.unstract.com \ + --target-org my-target-org +``` + +Both keys must be **org admin Platform API keys**. Run from a trusted machine. + +## How it works + +The SDK orchestrates Platform API calls in a strict dependency order. Each phase migrates one resource type and feeds a remap table (`source_uuid → target_uuid`) that later phases consume to rewrite embedded references before POST. + +Phase order: + +``` +1. adapter 7. republish_tool +2. connector 8. files (Prompt Studio document corpus) +3. tag 9. workflow +4. custom_tool 10. tool_instance +5. profile_manager 11. workflow_endpoint +6. 
prompt 12. api_deployment + 13. pipeline +``` + +Phases 4–7 are composite (run together under `CustomToolPhase`). + +## Failure semantics — important + +### DB writes are committed per-resource, not per-phase + +Each POST is a separate Django request and a separate transaction on the target. There is **no all-or-nothing transaction wrapping a phase**. Consequences: + +- If a phase fails on the Nth entity, entities 1..N-1 are present on target. Entity N is rolled back (its own transaction). Entities N+1..M are never attempted. +- Side-effects that ride on POSTs (API keys auto-minted for API deployments and pipelines, PeriodicTask rows for scheduled pipelines, `DocumentManager` rows for tool documents) are persisted alongside their parent — same per-resource atomicity. +- The in-memory `RemapTable` is process state and is lost on crash. Re-run rebuilds it via Layer 2 idempotency. + +### Re-runs are idempotent and cheap + +**None of the Platform API write endpoints are naturally idempotent** — POSTing the same adapter / connector / workflow / file twice produces two target rows. The SDK works around this with a uniform pattern: **pre-check the target by name before POSTing.** + +Every phase: +- Lists target by name filter (or by parent tool, for files) — one call per phase. +- For each source entity: if already present on target, record `src_uuid → tgt_uuid` in the in-memory remap and skip the POST. If missing, per-id GET on source for the full payload, remap UUIDs, POST. + +The endpoint stays non-idempotent; the SDK guarantees idempotency by **not invoking the endpoint twice**. + +On a clean re-run after a fully-successful migration, no POSTs fire. Cost reduces to one list call per phase per tool. Typical re-run time: 1–2 minutes for a moderate corpus vs. 7–10 minutes for the first run. + +On a re-run after a partial-failure crash, completed phases skip-everything; the crashed phase resumes from the first missing entity. 
+ +#### Files phase specifics + +The upload endpoint (`upload_for_ide`) writes the file to storage **first**, then creates the `DocumentManager` row. Both are unconditional — the endpoint has no upsert. Two consequences: + +- Partial-failure between "file written" and "DM row created" leaves a file with no DM row. The SDK's pre-check looks at DM rows (filenames); seeing no row, it retries the upload, the file is overwritten (storage-idempotent), and the DM row is created. **Self-healing on re-run.** +- Once both succeed, the SDK's pre-check on the next run sees the DM row and skips the upload call entirely. No duplicate DM row. + +The one realistic case where the SDK's pre-check can be defeated is **concurrent UI upload mid-migration**: a user uploading the same filename through the IDE after the SDK already listed target filenames. The SDK then uploads anyway, creating a duplicate DM row. **Mitigation: run migrations in low-activity windows.** + +### Files phase is the exception worth knowing about + +Files are uploaded per-file, one at a time. Each upload is its own request. Failure semantics match the metadata phases (per-file commit, no all-or-nothing). But two extra wrinkles: + +- **Oversize files** (above `--max-file-size`, default 25 MB) are not uploaded; they are recorded in `MigrationReport.oversize_files` and listed at end-of-phase for manual UI re-upload. The run does not abort. +- **If the files phase fails or is skipped**, the target has `CustomTool` + `DocumentManager` rows but no actual files in storage. The platform stays usable globally; per-file operations (preview, index, prompt-run) on missing files error cleanly. Users can re-upload missing files via the target UI's file manager. The platform doesn't crash. + +See [What if files aren't on disk?](#what-if-files-arent-on-disk) for details. + +### How to recover from a mid-failure crash + +1. Read the printed `MigrationReport` — completed phases + the entity that failed. +2. 
Fix the underlying issue (network, permissions, oversize payload, etc.). +3. Re-run the same command. The SDK picks up where it left off. + +There is no `--resume-from` flag and no state file. The target *is* the state. + +## What gets migrated + +| Resource | Notes | +|----------|-------| +| Adapters | Including decrypted `adapter_metadata` (carries secrets verbatim — same surface the FE already consumes) | +| Connectors | Same secrets posture | +| Tags | Per-org | +| Custom tools | + nested: profile_managers, prompts, document_manager rows | +| Prompt registry | Re-published on target via `update_or_create` (no manual carry) | +| Files | Prompt Studio document corpus per tool — see [Files phase](#files-phase) | +| Workflows | Workflow_name remapped | +| Tool instances | v1 assumes ≤1 per workflow | +| Workflow endpoints | Connector references remapped | +| API deployments | New API key minted on target (consumer keys regenerate; document for downstream consumers) | +| Pipelines | New API key + PeriodicTask auto-minted server-side. Default state: paused (`active=false`) — SDK PATCHes immediately after POST to avoid cron firing on a half-cut-over org | + +## Files phase + +The only resource type with bytes-on-disk that migrate. Storage path on both ends: + +``` +{PERMANENT_REMOTE_STORAGE}/{REMOTE_PROMPT_STUDIO_FILE_PATH}/{org_id}/{user_id}/{tool_id}/ +``` + +Server-generated subdirs (`extract/`, `summarize/`, `converted/`) are **out of scope** — regenerated on target on first index/summarize/preview. + +### Strategy + +Two modes; default is `platform_api`. + +| `--file-strategy` | Behavior | +|-------------------|----------| +| `platform_api` (default) | Download each file via existing `fetch_contents_ide` endpoint, upload via `upload_for_ide`. Cap per file = `--max-file-size` (default 25 MB). Files over cap are reported for manual re-upload, not aborted. Concurrency = 1. | +| `skip` | No bytes touched. 
`DocumentManager` rows present on target (from `CustomToolPhase`), files missing on disk. Report lists every expected filename for manual UI re-upload. Equivalent: `--skip-files`. | + +### What if files aren't on disk? + +After a `skip`, after oversize-file reporting in `platform_api` mode, or after a mid-failure crash before bytes were transferred, the target has `DocumentManager` rows that reference files not present in storage. **The platform stays usable globally.** Specifically: + +- Tool/workflow/deployment/pipeline listing and navigation: works. +- Opening any CustomTool in Prompt Studio: works. +- Per-file preview pane: errors with a `FileNotFoundError`-derived 500 (no explicit handler in the view today). UI shows an error. +- Index document / run prompt against missing file: errors cleanly (explicit handler). +- Re-upload via UI: works — restores the file. **Caveat:** the upload endpoint is not idempotent at the DB layer; it creates a new `DocumentManager` row unconditionally. If migration already created a DM row for that filename (it does via `CustomToolPhase`), the UI re-upload produces a second DM row pointing at the same (overwritten) file. UI will list the file twice. Delete the stale row via UI first, then re-upload, to avoid duplicates. The SDK itself avoids this trap by pre-checking DM rows before any upload call; the duplicate is purely a UI-side re-upload artifact. +- All other tools/workflows that have their files: unaffected. + +So a partial files migration leaves users able to use the platform broadly; only the specific missing files surface errors when touched. + +## Constraints and trade-offs + +- **No new backend API endpoints** — files phase uses what exists today. The download path eats a ~33% base64 inflation and one-shot full-file memory on both ends. That's why the size cap is conservative. +- **Storage-backend direct copy (`gsutil rsync` / `aws s3 sync`) not supported** — per-provider CLI maintenance burden is too high. 
+- **No state file** — idempotency relies on target being queryable by name. If you delete a target resource between runs, the SDK recreates it on the next run. +- **No UUID preservation** — every target resource gets a freshly minted UUID. Embedded references are remapped via the in-memory `RemapTable`. + +## Configuration reference + +### Environment + +| Var | Required | Purpose | +|-----|----------|---------| +| `UNSTRACT_SRC_PLATFORM_KEY` | yes | Source org admin Platform API key | +| `UNSTRACT_TGT_PLATFORM_KEY` | yes | Target org admin Platform API key | + +### CLI flags + +| Flag | Default | Purpose | +|------|---------|---------| +| `--source-url` / `--target-url` | — | Base URLs of both deployments | +| `--source-org` / `--target-org` | — | Org slugs | +| `--api-prefix` | `api/v1` | URL prefix; varies on cloud | +| `--include` / `--exclude` | all / none | Phase filter (comma-separated phase names) | +| `--dry-run` | off | List actions, don't POST | +| `--on-name-conflict` | `adopt` | `adopt` (skip existing) or `abort` | +| `--file-strategy` | `platform_api` | `platform_api` or `skip` | +| `--max-file-size` | `25MB` | Per-file cap for files phase | +| `--skip-files` | off | Alias for `--file-strategy=skip` | +| `--pipelines-paused` | on | Toggle the post-POST PATCH that pauses pipelines on target | +| `--verbose` | off | Per-entity log lines | + +## Report shape + +`MigrationReport` exposes: + +- `created` / `adopted` / `failed` counts per phase +- `uploaded_files: list[{tool_id, tool_name, file_name, size_bytes, mime_type}]` +- `oversize_files: list[{tool_id, tool_name, file_name, size_bytes, cap_bytes}]` +- `unsupported_files: list[{tool_id, tool_name, file_name, mime_type}]` +- `skipped_files: list[{tool_id, tool_name, file_name, source_org_slug, source_tool_id}]` +- `failed_files: list[{tool_id, tool_name, file_name, error}]` +- `remap_snapshot: dict[entity_type, dict[src_uuid, tgt_uuid]]` +- A pretty-printed source-to-target UUID map at end (rich-formatted; plain-text fallback) + +## Logging hygiene + +- Secret values (adapter/connector metadata) are not logged. 
+- File request/response bodies are not logged. +- Per-entity log lines format: `src=<src_uuid> -> tgt=<tgt_uuid>` plus entity name + type. +- Rotate both Platform API keys after the migration completes. + +## Further reading + +- KB: `~/Documents/Obsidian Vault/zipstuff/org-data-migration/` (start with `INDEX.md`) +- Implementation plan for the files phase: `docs/internal/files-migration-plan.md` diff --git a/src/unstract/migration/cli.py b/src/unstract/migration/cli.py index c03311e..0c8339f 100644 --- a/src/unstract/migration/cli.py +++ b/src/unstract/migration/cli.py @@ -9,15 +9,44 @@ from __future__ import annotations import logging +import re import sys from typing import Any import click -from unstract.migration.context import MigrationOptions, OrgEndpoint +from unstract.migration.context import ( + DEFAULT_MAX_FILE_SIZE, + MigrationOptions, + OrgEndpoint, +) from unstract.migration.exceptions import MigrationError from unstract.migration.orchestrator import migrate as run_migrate +_SIZE_UNITS: dict[str, int] = { + "B": 1, + "K": 1024, + "KB": 1024, + "M": 1024 * 1024, + "MB": 1024 * 1024, + "G": 1024 * 1024 * 1024, + "GB": 1024 * 1024 * 1024, +} +_SIZE_RE = re.compile(r"^\s*(\d+(?:\.\d+)?)\s*([A-Za-z]*)\s*$") + + +def _parse_size(value: str) -> int: + """Accept ``25``, ``25MB``, ``1.5GB`` etc. 
Returns bytes.""" + m = _SIZE_RE.match(value) + if not m: + raise click.BadParameter(f"can't parse size '{value}'") + num, unit = m.group(1), m.group(2).upper() or "B" + if unit not in _SIZE_UNITS: + raise click.BadParameter( + f"unknown size unit '{unit}'; use one of {sorted(_SIZE_UNITS)}" + ) + return int(float(num) * _SIZE_UNITS[unit]) + def _configure_logging(verbose: bool) -> None: level = logging.DEBUG if verbose else logging.INFO @@ -80,6 +109,24 @@ def cli() -> None: show_default=True, help="Backend URL prefix (matches deployment's PATH_PREFIX env)", ) +@click.option( + "--file-strategy", + type=click.Choice(["platform_api", "skip"]), + default="platform_api", + show_default=True, + help="How to move Prompt Studio document files. 'skip' = metadata only.", +) +@click.option( + "--max-file-size", + default="25MB", + show_default=True, + help="Per-file cap for the files phase. Oversize → reported, not aborted.", +) +@click.option( + "--skip-files", + is_flag=True, + help="Alias for --file-strategy=skip.", +) @click.option("-v", "--verbose", is_flag=True, help="Debug logging") def migrate_cmd( source_url: str, @@ -93,17 +140,28 @@ def migrate_cmd( exclude: str | None, on_name_conflict: str, api_prefix: str, + file_strategy: str, + max_file_size: str, + skip_files: bool, verbose: bool, ) -> None: """Migrate configured resources from one org to another.""" _configure_logging(verbose) + effective_strategy = "skip" if skip_files else file_strategy + try: + cap_bytes = _parse_size(max_file_size) + except click.BadParameter as e: + raise click.UsageError(str(e)) from e + options = MigrationOptions( dry_run=dry_run, include=_split_csv(include), exclude=_split_csv(exclude) or (), on_name_conflict=on_name_conflict, verbose=verbose, + file_strategy=effective_strategy, + max_file_size=cap_bytes or DEFAULT_MAX_FILE_SIZE, ) source = OrgEndpoint( diff --git a/src/unstract/migration/client.py b/src/unstract/migration/client.py index 183cefd..7b3bff3 100644 --- 
a/src/unstract/migration/client.py +++ b/src/unstract/migration/client.py @@ -252,6 +252,54 @@ def sync_prompts( "POST", f"prompt-studio/{tool_id}/sync-prompts/", json=payload ) + def list_prompt_documents(self, tool_id: str) -> list[dict[str, Any]]: + """List DocumentManager rows for a tool. + + Used by FilesPhase for target-side idempotency and source-side + enumeration. Response items carry ``document_id``, + ``document_name``, and ``tool`` (per the serializer's + ``to_representation`` filter). + """ + result = self._request( + "GET", "prompt-document/", params={"tool_id": tool_id} + ) + return result if isinstance(result, list) else result.get("results", []) + + def download_prompt_file( + self, tool_id: str, file_name: str + ) -> dict[str, Any]: + """GET a Prompt Studio document by tool + filename. + + Returns the backend's ``{"data": ..., "mime_type": ...}`` envelope + verbatim. PDFs come back as base64; text/csv as decoded utf-8; + Excel returns a placeholder string (not real bytes) — callers must + treat unsupported mime types as needing manual re-upload. + """ + return self._request( + "GET", + f"prompt-studio/file/{tool_id}", + params={"file_name": file_name}, + ) + + def upload_prompt_file( + self, + tool_id: str, + file_name: str, + data: bytes, + mime_type: str, + ) -> dict[str, Any]: + """Upload a file into a target Prompt Studio tool. + + Backend writes bytes to storage and creates a ``DocumentManager`` + row. The DM model has ``UniqueConstraint(document_name, tool)``, + so callers must pre-check via ``list_prompt_documents`` to avoid + an IntegrityError → 500 on re-runs. + """ + files = {"file": (file_name, data, mime_type)} + return self._request( + "POST", f"prompt-studio/file/{tool_id}/", files=files + ) + def export_custom_tool(self, tool_id: str, *, force: bool = True) -> Any: """Republish ``PromptStudioRegistry`` from the tool's current state. 
diff --git a/src/unstract/migration/context.py b/src/unstract/migration/context.py index 98d7ce7..c3beec2 100644 --- a/src/unstract/migration/context.py +++ b/src/unstract/migration/context.py @@ -34,6 +34,9 @@ class OrgEndpoint: api_path_prefix: str = "api/v1" +DEFAULT_MAX_FILE_SIZE = 25 * 1024 * 1024 # 25 MB; oversize → manual-upload list + + @dataclass class MigrationOptions: """Per-run flags for ``migrate()``.""" @@ -43,6 +46,10 @@ class MigrationOptions: exclude: tuple[str, ...] = () on_name_conflict: str = "adopt" # "adopt" | "abort" verbose: bool = False + # "platform_api": download/upload via existing endpoints (default). + # "skip": metadata only; operator re-uploads via UI on target. + file_strategy: str = "platform_api" + max_file_size: int = DEFAULT_MAX_FILE_SIZE def includes(self, phase_name: str) -> bool: if self.include is not None and phase_name not in self.include: diff --git a/src/unstract/migration/orchestrator.py b/src/unstract/migration/orchestrator.py index 5701d70..6b5dc8b 100644 --- a/src/unstract/migration/orchestrator.py +++ b/src/unstract/migration/orchestrator.py @@ -21,6 +21,7 @@ APIDeploymentPhase, ConnectorPhase, CustomToolPhase, + FilesPhase, PipelinePhase, TagPhase, ToolInstancePhase, @@ -43,6 +44,7 @@ ("connector", ConnectorPhase), ("tag", TagPhase), ("custom_tool", CustomToolPhase), + ("files", FilesPhase), ("workflow", WorkflowPhase), ("tool_instance", ToolInstancePhase), ("workflow_endpoint", WorkflowEndpointPhase), diff --git a/src/unstract/migration/phases/__init__.py b/src/unstract/migration/phases/__init__.py index bde8030..00e8312 100644 --- a/src/unstract/migration/phases/__init__.py +++ b/src/unstract/migration/phases/__init__.py @@ -12,6 +12,7 @@ from unstract.migration.phases.base import Phase from unstract.migration.phases.connector import ConnectorPhase from unstract.migration.phases.custom_tool import CustomToolPhase +from unstract.migration.phases.files import FilesPhase from unstract.migration.phases.pipeline 
import PipelinePhase from unstract.migration.phases.tag import TagPhase from unstract.migration.phases.tool_instance import ToolInstancePhase @@ -23,6 +24,7 @@ "AdapterPhase", "ConnectorPhase", "CustomToolPhase", + "FilesPhase", "Phase", "PipelinePhase", "TagPhase", diff --git a/src/unstract/migration/phases/files.py b/src/unstract/migration/phases/files.py new file mode 100644 index 0000000..b76ab28 --- /dev/null +++ b/src/unstract/migration/phases/files.py @@ -0,0 +1,325 @@ +"""Migrate Prompt Studio document files (the user-uploaded test corpus). + +Runs after ``CustomToolPhase`` — consumes the ``custom_tool`` remap to +know which source-tool to target-tool mapping to iterate. + +Default mode (``file_strategy='platform_api'``): + +1. For each ``(src_tool_id, tgt_tool_id)``, list source DM rows + target + DM rows once each. +2. For each source filename missing on target: download from source, decode + per mime, enforce the size cap, upload as multipart to target. +3. Oversize files → ``MigrationReport.oversize_files``; mime types the + backend can't round-trip (Excel placeholder, etc) → + ``unsupported_files``; transport errors → ``failed_files``. + +Skip mode (``file_strategy='skip'``): + +- No download/upload. Source DM list is emitted into ``skipped_files`` so + the operator knows what to re-upload manually via UI. + +Concurrency is 1 per phase by design — the Platform API endpoint holds a +cloud worker for the whole upload, and uploads are not chunked on the BE +helper today. See ``docs/internal/files-migration-plan.md`` for the +sizing rationale. 
+""" + +from __future__ import annotations + +import base64 +import logging +import time +from typing import Any + +import requests + +from unstract.migration.exceptions import PlatformAPIError +from unstract.migration.phases.base import Phase +from unstract.migration.report import MigrationReport, PhaseResult + +logger = logging.getLogger(__name__) + +# Mime types the BE's fetch_contents_ide endpoint round-trips losslessly. +# PDF → base64; text/plain + text/csv → utf-8 string. Excel and other +# types return a placeholder/unhandled — must be flagged for manual upload. +_BASE64_MIMES: frozenset[str] = frozenset({"application/pdf"}) +_TEXT_MIMES: frozenset[str] = frozenset({"text/plain", "text/csv"}) + +_RETRYABLE_STATUS: frozenset[int] = frozenset({502, 503, 504}) +_MAX_RETRIES = 3 +_RETRY_BACKOFF_BASE_SECONDS = 1.0 + + +class FilesPhase(Phase): + name = "files" + + def run(self, report: MigrationReport) -> PhaseResult: + result = report.get_phase(self.name) + tool_remap = self.ctx.remap.snapshot().get("custom_tool", {}) + if not tool_remap: + logger.info("files phase: no custom_tool remap entries; nothing to do") + return result + + strategy = self.ctx.options.file_strategy + logger.info( + "files phase: strategy=%s tools=%d cap=%d bytes", + strategy, len(tool_remap), self.ctx.options.max_file_size, + ) + + for src_tool_id, tgt_tool_id in tool_remap.items(): + tool_name = self._lookup_tool_name(tgt_tool_id) or src_tool_id + try: + src_docs = self.ctx.source.list_prompt_documents(src_tool_id) + except Exception as e: + logger.exception( + "files: failed to list source DM rows for tool %s: %s", + tool_name, e, + ) + result.failed += 1 + result.errors.append( + f"list source docs {tool_name}: {e}" + ) + continue + + if strategy == "skip": + self._emit_skip(src_docs, src_tool_id, tgt_tool_id, tool_name, report, result) + continue + + self._migrate_tool( + src_tool_id, tgt_tool_id, tool_name, src_docs, report, result + ) + + return result + + def _migrate_tool( + 
self, + src_tool_id: str, + tgt_tool_id: str, + tool_name: str, + src_docs: list[dict[str, Any]], + report: MigrationReport, + result: PhaseResult, + ) -> None: + try: + tgt_docs = self.ctx.target.list_prompt_documents(tgt_tool_id) + except Exception as e: + logger.exception( + "files: failed to list target DM rows for tool %s: %s", + tool_name, e, + ) + result.failed += 1 + result.errors.append(f"list target docs {tool_name}: {e}") + return + target_names = {d["document_name"] for d in tgt_docs} + + for doc in src_docs: + file_name = doc.get("document_name") + if not file_name: + continue + if file_name in target_names: + result.skipped += 1 + logger.debug( + "files: already present on target tool=%s file=%s", + tool_name, file_name, + ) + continue + if self.ctx.options.dry_run: + result.skipped += 1 + logger.info( + "[dry-run] files: would migrate tool=%s file=%s", + tool_name, file_name, + ) + continue + self._migrate_one_file( + src_tool_id, tgt_tool_id, tool_name, file_name, report, result + ) + + def _migrate_one_file( + self, + src_tool_id: str, + tgt_tool_id: str, + tool_name: str, + file_name: str, + report: MigrationReport, + result: PhaseResult, + ) -> None: + try: + payload = self._with_retry( + lambda: self.ctx.source.download_prompt_file(src_tool_id, file_name), + op=f"download {tool_name}/{file_name}", + ) + except Exception as e: + logger.exception( + "files: download failed tool=%s file=%s: %s", + tool_name, file_name, e, + ) + result.failed += 1 + report.failed_files.append( + { + "tool_id": tgt_tool_id, + "tool_name": tool_name, + "file_name": file_name, + "error": f"download: {e}", + } + ) + return + + mime = (payload or {}).get("mime_type") or "" + raw = self._decode_payload(payload, mime) + if raw is None: + logger.warning( + "files: unsupported mime tool=%s file=%s mime=%s", + tool_name, file_name, mime, + ) + report.unsupported_files.append( + { + "tool_id": tgt_tool_id, + "tool_name": tool_name, + "file_name": file_name, + "mime_type": 
mime, + } + ) + return + + if len(raw) > self.ctx.options.max_file_size: + report.oversize_files.append( + { + "tool_id": tgt_tool_id, + "tool_name": tool_name, + "file_name": file_name, + "size_bytes": len(raw), + "cap_bytes": self.ctx.options.max_file_size, + } + ) + logger.info( + "files: oversize tool=%s file=%s size=%d cap=%d", + tool_name, file_name, len(raw), self.ctx.options.max_file_size, + ) + return + + try: + self._with_retry( + lambda: self.ctx.target.upload_prompt_file( + tgt_tool_id, file_name, raw, mime + ), + op=f"upload {tool_name}/{file_name}", + ) + except Exception as e: + logger.exception( + "files: upload failed tool=%s file=%s: %s", + tool_name, file_name, e, + ) + result.failed += 1 + report.failed_files.append( + { + "tool_id": tgt_tool_id, + "tool_name": tool_name, + "file_name": file_name, + "error": f"upload: {e}", + } + ) + return + + result.created += 1 + report.uploaded_files.append( + { + "tool_id": tgt_tool_id, + "tool_name": tool_name, + "file_name": file_name, + "size_bytes": len(raw), + "mime_type": mime, + } + ) + logger.info( + "files: uploaded tool=%s file=%s size=%d", + tool_name, file_name, len(raw), + ) + + def _emit_skip( + self, + src_docs: list[dict[str, Any]], + src_tool_id: str, + tgt_tool_id: str, + tool_name: str, + report: MigrationReport, + result: PhaseResult, + ) -> None: + for doc in src_docs: + file_name = doc.get("document_name") + if not file_name: + continue + report.skipped_files.append( + { + "tool_id": tgt_tool_id, + "tool_name": tool_name, + "file_name": file_name, + "source_org_slug": self.ctx.source.endpoint.organization_id, + "source_tool_id": src_tool_id, + } + ) + result.skipped += 1 + logger.info( + "files: skip mode emitted %d filenames for tool=%s", + len(src_docs), tool_name, + ) + + def _decode_payload( + self, payload: dict[str, Any] | None, mime: str + ) -> bytes | None: + if not payload: + return None + data_field = payload.get("data") + if data_field is None: + return None + if mime in 
_BASE64_MIMES: + # data_field is base64-encoded bytes (BE wraps with b64encode). + if isinstance(data_field, bytes): + return base64.b64decode(data_field) + return base64.b64decode(data_field.encode()) + if mime in _TEXT_MIMES: + if isinstance(data_field, bytes): + return data_field + return data_field.encode("utf-8") + # Excel + unhandled types: BE returned a placeholder string, + # not real bytes. Round-trip would corrupt the file. + return None + + def _lookup_tool_name(self, tgt_tool_id: str) -> str | None: + # CustomToolPhase doesn't record names; fetch lazily for log clarity. + # One call per tool is cheap relative to the per-file traffic. + try: + tools = self.ctx.target.list_custom_tools() + except Exception: + return None + for t in tools: + if t.get("tool_id") == tgt_tool_id: + return t.get("tool_name") + return None + + def _with_retry(self, fn: Any, *, op: str) -> Any: + last_exc: Exception | None = None + for attempt in range(1, _MAX_RETRIES + 1): + try: + return fn() + except PlatformAPIError as e: + last_exc = e + if e.status_code not in _RETRYABLE_STATUS or attempt == _MAX_RETRIES: + raise + sleep = _RETRY_BACKOFF_BASE_SECONDS * (2 ** (attempt - 1)) + logger.warning( + "files: retry %d/%d for %s after %d: sleeping %.1fs", + attempt, _MAX_RETRIES, op, e.status_code, sleep, + ) + time.sleep(sleep) + except (requests.ConnectionError, requests.Timeout) as e: + last_exc = e + if attempt == _MAX_RETRIES: + raise + sleep = _RETRY_BACKOFF_BASE_SECONDS * (2 ** (attempt - 1)) + logger.warning( + "files: retry %d/%d for %s after %s: sleeping %.1fs", + attempt, _MAX_RETRIES, op, type(e).__name__, sleep, + ) + time.sleep(sleep) + assert last_exc is not None + raise last_exc diff --git a/src/unstract/migration/report.py b/src/unstract/migration/report.py index e07ea75..a663ffa 100644 --- a/src/unstract/migration/report.py +++ b/src/unstract/migration/report.py @@ -28,6 +28,13 @@ class MigrationReport: remap_snapshot: dict[str, dict[str, str]] = 
field(default_factory=dict) aborted: bool = False abort_reason: str | None = None + # Files-phase artifacts. Each entry carries enough context for an + # operator to act on it without cross-referencing the run log. + uploaded_files: list[dict[str, Any]] = field(default_factory=list) + skipped_files: list[dict[str, Any]] = field(default_factory=list) + oversize_files: list[dict[str, Any]] = field(default_factory=list) + unsupported_files: list[dict[str, Any]] = field(default_factory=list) + failed_files: list[dict[str, Any]] = field(default_factory=list) def get_phase(self, name: str) -> PhaseResult: for p in self.phases: @@ -57,6 +64,7 @@ def render(self) -> str: console.print(table) if self.skipped_phases: console.print(f"[dim]Skipped phases:[/dim] {', '.join(self.skipped_phases)}") + self._render_files_sections(console) if self.remap_snapshot: remap = Table(title="Source -> Target UUID Map") remap.add_column("Entity") @@ -80,6 +88,7 @@ def _render_plain(self) -> str: ) if self.skipped_phases: lines.append(f"Skipped phases: {', '.join(self.skipped_phases)}") + lines.extend(self._files_sections_plain()) if self.remap_snapshot: lines.append("") lines.append("Source -> Target UUID Map") @@ -108,4 +117,57 @@ def as_dict(self) -> dict[str, Any]: "remap_snapshot": self.remap_snapshot, "aborted": self.aborted, "abort_reason": self.abort_reason, + "uploaded_files": list(self.uploaded_files), + "skipped_files": list(self.skipped_files), + "oversize_files": list(self.oversize_files), + "unsupported_files": list(self.unsupported_files), + "failed_files": list(self.failed_files), } + + def _render_files_sections(self, console: Any) -> None: + if self.uploaded_files: + console.print( + f"[green]Files uploaded:[/green] {len(self.uploaded_files)}" + ) + for header, rows in ( + ("Oversize files (manual upload required)", self.oversize_files), + ("Unsupported mime files (manual upload required)", self.unsupported_files), + ("Skipped files (operator action required)", 
self.skipped_files), + ("Failed files", self.failed_files), + ): + if not rows: + continue + console.print(f"[yellow]{header}:[/yellow]") + for row in rows: + console.print(f" - {self._describe_file_row(row)}") + + def _files_sections_plain(self) -> list[str]: + lines: list[str] = [] + if self.uploaded_files: + lines.append(f"Files uploaded: {len(self.uploaded_files)}") + for header, rows in ( + ("Oversize files (manual upload required)", self.oversize_files), + ("Unsupported mime files (manual upload required)", self.unsupported_files), + ("Skipped files (operator action required)", self.skipped_files), + ("Failed files", self.failed_files), + ): + if not rows: + continue + lines.append(f"{header}:") + for row in rows: + lines.append(f" - {self._describe_file_row(row)}") + return lines + + @staticmethod + def _describe_file_row(row: dict[str, Any]) -> str: + tool = row.get("tool_name") or row.get("tool_id") or "?" + name = row.get("file_name", "?") + extras: list[str] = [] + if "size_bytes" in row: + extras.append(f"{row['size_bytes']} bytes") + if "mime_type" in row: + extras.append(row["mime_type"]) + if "error" in row: + extras.append(f"error={row['error']}") + suffix = f" ({', '.join(extras)})" if extras else "" + return f"tool={tool} file={name}{suffix}" diff --git a/tests/migration/test_files_phase.py b/tests/migration/test_files_phase.py new file mode 100644 index 0000000..85d3196 --- /dev/null +++ b/tests/migration/test_files_phase.py @@ -0,0 +1,391 @@ +"""Tests for ``FilesPhase``. + +Coverage: +- happy path: PDF + text/csv files uploaded with base64 + utf-8 decoding. +- target-side idempotency: filename already present → skip, no upload. +- oversize file → ``oversize_files`` entry, sibling files continue. +- unsupported mime (Excel placeholder) → ``unsupported_files`` entry. +- skip strategy → no uploads, source filenames listed in ``skipped_files``. +- dry-run → no uploads even for missing files. +- transient 503 → retried, eventual success. 
+- no custom_tool remap → no-op. +- listing failure on source aborts only that tool, others continue. +""" + +from __future__ import annotations + +import base64 +from typing import Any + +import pytest + +from unstract.migration.context import ( + MigrationContext, + MigrationOptions, + OrgEndpoint, + RemapTable, +) +from unstract.migration.exceptions import PlatformAPIError +from unstract.migration.phases.files import FilesPhase +from unstract.migration.report import MigrationReport + + +SRC_ENDPOINT = OrgEndpoint( + base_url="http://src", organization_id="src-org", platform_key="src-key" +) +TGT_ENDPOINT = OrgEndpoint( + base_url="http://tgt", organization_id="tgt-org", platform_key="tgt-key" +) + + +class FakeClient: + def __init__( + self, + *, + endpoint: OrgEndpoint, + documents: dict[str, list[dict]] | None = None, + file_payloads: dict[tuple[str, str], dict] | None = None, + tools: list[dict] | None = None, + ): + self.endpoint = endpoint + # tool_id -> list of {document_name, document_id, tool} + self._documents: dict[str, list[dict]] = { + k: list(v) for k, v in (documents or {}).items() + } + # (tool_id, file_name) -> {"data": ..., "mime_type": ...} + self._file_payloads: dict[tuple[str, str], dict] = dict( + file_payloads or {} + ) + self._tools = list(tools or []) + self.uploaded: list[dict[str, Any]] = [] + self.list_calls: list[str] = [] + self.download_calls: list[tuple[str, str]] = [] + # Configurable fault injection. 
+ self.download_errors: dict[tuple[str, str], list[Exception]] = {} + self.upload_errors: dict[tuple[str, str], list[Exception]] = {} + self.list_errors: dict[str, Exception] = {} + self._next_id = 1 + + def list_prompt_documents(self, tool_id: str) -> list[dict]: + self.list_calls.append(tool_id) + if tool_id in self.list_errors: + raise self.list_errors[tool_id] + return [dict(d) for d in self._documents.get(tool_id, [])] + + def download_prompt_file(self, tool_id: str, file_name: str) -> dict: + self.download_calls.append((tool_id, file_name)) + queue = self.download_errors.get((tool_id, file_name)) + if queue: + raise queue.pop(0) + return dict(self._file_payloads[(tool_id, file_name)]) + + def upload_prompt_file( + self, tool_id: str, file_name: str, data: bytes, mime_type: str + ) -> dict: + queue = self.upload_errors.get((tool_id, file_name)) + if queue: + raise queue.pop(0) + doc_id = f"doc-{self._next_id:04d}" + self._next_id += 1 + self.uploaded.append( + { + "tool_id": tool_id, + "file_name": file_name, + "data": data, + "mime_type": mime_type, + } + ) + self._documents.setdefault(tool_id, []).append( + {"document_id": doc_id, "document_name": file_name, "tool": tool_id} + ) + return {"document_id": doc_id} + + def list_custom_tools(self) -> list[dict]: + return list(self._tools) + + +def _ctx(src: FakeClient, tgt: FakeClient, *, remap: RemapTable | None = None, + **opts) -> MigrationContext: + remap = remap or RemapTable() + return MigrationContext( + source=src, + target=tgt, + options=MigrationOptions(**opts), + remap=remap, + ) + + +def _doc(name: str) -> dict: + return {"document_id": f"src-{name}", "document_name": name, "tool": "ignored"} + + +def _pdf_payload(raw: bytes) -> dict: + return {"data": base64.b64encode(raw).decode(), "mime_type": "application/pdf"} + + +def _text_payload(text: str, mime: str = "text/plain") -> dict: + return {"data": text, "mime_type": mime} + + +def test_happy_path_uploads_pdf_and_text(): + src = FakeClient( + 
endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("invoice.pdf"), _doc("notes.txt")]}, + file_payloads={ + ("src-1", "invoice.pdf"): _pdf_payload(b"%PDF-FAKE"), + ("src-1", "notes.txt"): _text_payload("hello world"), + }, + ) + tgt = FakeClient(endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}]) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + report = MigrationReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 2 + assert result.failed == 0 + assert {u["file_name"] for u in tgt.uploaded} == {"invoice.pdf", "notes.txt"} + pdf_upload = next(u for u in tgt.uploaded if u["file_name"] == "invoice.pdf") + assert pdf_upload["data"] == b"%PDF-FAKE" + assert pdf_upload["mime_type"] == "application/pdf" + txt_upload = next(u for u in tgt.uploaded if u["file_name"] == "notes.txt") + assert txt_upload["data"] == b"hello world" + assert len(report.uploaded_files) == 2 + assert all(u["tool_name"] == "demo" for u in report.uploaded_files) + + +def test_target_filename_present_is_skipped_no_download(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("invoice.pdf")]}, + file_payloads={("src-1", "invoice.pdf"): _pdf_payload(b"BYTES")}, + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, + documents={"tgt-1": [_doc("invoice.pdf")]}, + tools=[{"tool_id": "tgt-1", "tool_name": "demo"}], + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + report = MigrationReport() + + result = FilesPhase(ctx).run(report) + + assert result.skipped == 1 + assert result.created == 0 + assert tgt.uploaded == [] + assert src.download_calls == [] # pre-check guards the download + + +def test_oversize_file_is_recorded_and_siblings_continue(): + big = b"X" * 50 + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("big.pdf"), _doc("small.txt")]}, + file_payloads={ + ("src-1", "big.pdf"): 
_pdf_payload(big), + ("src-1", "small.txt"): _text_payload("ok"), + }, + ) + tgt = FakeClient(endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}]) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap, max_file_size=10) + report = MigrationReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 1 + assert {u["file_name"] for u in tgt.uploaded} == {"small.txt"} + assert len(report.oversize_files) == 1 + over = report.oversize_files[0] + assert over["file_name"] == "big.pdf" + assert over["size_bytes"] == 50 + assert over["cap_bytes"] == 10 + + +def test_unsupported_mime_is_recorded_not_uploaded(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("sheet.xlsx")]}, + file_payloads={ + ("src-1", "sheet.xlsx"): { + "data": "Preview not available for Excel files. ...", + "mime_type": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + } + }, + ) + tgt = FakeClient(endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}]) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + report = MigrationReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 0 + assert tgt.uploaded == [] + assert len(report.unsupported_files) == 1 + entry = report.unsupported_files[0] + assert entry["file_name"] == "sheet.xlsx" + assert entry["mime_type"].startswith("application/vnd.openxmlformats") + + +def test_skip_strategy_emits_skipped_files_no_traffic(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf"), _doc("b.pdf")]}, + ) + tgt = FakeClient(endpoint=TGT_ENDPOINT) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap, file_strategy="skip") + report = MigrationReport() + + result = FilesPhase(ctx).run(report) + + assert result.skipped == 2 + assert tgt.uploaded == [] + assert 
src.download_calls == [] + names = {row["file_name"] for row in report.skipped_files} + assert names == {"a.pdf", "b.pdf"} + assert all(row["source_org_slug"] == "src-org" for row in report.skipped_files) + assert all(row["source_tool_id"] == "src-1" for row in report.skipped_files) + + +def test_dry_run_makes_no_writes_even_for_missing_files(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf")]}, + file_payloads={("src-1", "a.pdf"): _pdf_payload(b"X")}, + ) + tgt = FakeClient(endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}]) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap, dry_run=True) + report = MigrationReport() + + result = FilesPhase(ctx).run(report) + + assert result.skipped == 1 + assert result.created == 0 + assert tgt.uploaded == [] + assert src.download_calls == [] + + +def test_transient_503_is_retried_then_succeeds(monkeypatch): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf")]}, + file_payloads={("src-1", "a.pdf"): _pdf_payload(b"OK")}, + ) + src.download_errors[("src-1", "a.pdf")] = [ + PlatformAPIError("flaky", status_code=503, body="") + ] + tgt = FakeClient(endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}]) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + # Strip the backoff sleep so the test stays fast. 
+ monkeypatch.setattr("unstract.migration.phases.files.time.sleep", lambda *_: None) + report = MigrationReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 1 + assert tgt.uploaded[0]["data"] == b"OK" + + +def test_no_custom_tool_remap_is_noop(): + src = FakeClient(endpoint=SRC_ENDPOINT) + tgt = FakeClient(endpoint=TGT_ENDPOINT) + ctx = _ctx(src, tgt) + report = MigrationReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 0 + assert result.skipped == 0 + assert src.list_calls == [] + + +def test_source_list_failure_isolates_to_that_tool(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-2": [_doc("ok.pdf")]}, + file_payloads={("src-2", "ok.pdf"): _pdf_payload(b"OK")}, + ) + src.list_errors["src-1"] = RuntimeError("source down for this tool") + tgt = FakeClient( + endpoint=TGT_ENDPOINT, + tools=[ + {"tool_id": "tgt-1", "tool_name": "broken"}, + {"tool_id": "tgt-2", "tool_name": "healthy"}, + ], + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + remap.record("custom_tool", "src-2", "tgt-2") + ctx = _ctx(src, tgt, remap=remap) + report = MigrationReport() + + result = FilesPhase(ctx).run(report) + + assert result.failed == 1 + assert result.created == 1 + assert {u["file_name"] for u in tgt.uploaded} == {"ok.pdf"} + + +def test_upload_failure_records_failed_files_entry(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf")]}, + file_payloads={("src-1", "a.pdf"): _pdf_payload(b"X")}, + ) + tgt = FakeClient(endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}]) + tgt.upload_errors[("tgt-1", "a.pdf")] = [ + PlatformAPIError("bad", status_code=400, body="bad") + ] + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + report = MigrationReport() + + result = FilesPhase(ctx).run(report) + + assert result.failed == 1 + assert result.created == 0 + assert len(report.failed_files) 
== 1 + entry = report.failed_files[0] + assert entry["file_name"] == "a.pdf" + assert "upload" in entry["error"] + + +@pytest.mark.parametrize( + "mime,raw", + [ + ("text/csv", "name,age\nalice,30"), + ("text/plain", "plain old text"), + ], +) +def test_text_mimes_round_trip_as_utf8(mime, raw): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("data")]}, + file_payloads={("src-1", "data"): _text_payload(raw, mime=mime)}, + ) + tgt = FakeClient(endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}]) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + report = MigrationReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 1 + upload = tgt.uploaded[0] + assert upload["data"] == raw.encode("utf-8") + assert upload["mime_type"] == mime From fe66b0512f0369816b0b552e8fb9e5f2ef890210 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Sun, 24 May 2026 16:13:16 +0530 Subject: [PATCH 13/25] fix(migration): files phase end-to-end + report quieting + tool_instance NOT FOUND Multiple fixes uncovered by the first local-stack run: client.py - list_prompt_documents: mount under prompt-studio/ prefix (BE include in urls_v2.py). - download_prompt_file: use ?document_id=, matching fetch_contents_ide serializer (was ?file_name=, BE ignored it and returned 400 ValidationError). - upload_prompt_file: drop trailing slash, BE pattern is prompt-studio/file/ with no slash so POST 404'd. - add get_custom_tool / update_custom_tool for default-doc PATCH. phases/files.py - After upload loop per tool, mirror source's CustomTool.output by filename so FE auto-selects on load. Fall back to first target doc. Preserve any existing target output (operator may have already picked manually on a re-run). 
phases/tool_instance.py - Detect source serializer sentinels ([X NOT FOUND], [DELETED ADAPTER ...], [NEEDS UPDATE]) in stored metadata and skip the PATCH instead of round-tripping a broken adapter reference. ToolInstance row exists with backend defaults; operator re-binds in UI. report.py - Drop the full source->target UUID map from rendered output (noisy on large migrations). Print per-entity counts only; full map still in as_dict() and at DEBUG log level. Co-Authored-By: Claude Opus 4.7 --- src/unstract/migration/client.py | 35 ++++-- src/unstract/migration/phases/files.py | 116 +++++++++++++++++- .../migration/phases/tool_instance.py | 52 ++++++-- src/unstract/migration/report.py | 40 +++--- tests/migration/test_files_phase.py | 95 +++++++++++++- 5 files changed, 301 insertions(+), 37 deletions(-) diff --git a/src/unstract/migration/client.py b/src/unstract/migration/client.py index 7b3bff3..9f17c96 100644 --- a/src/unstract/migration/client.py +++ b/src/unstract/migration/client.py @@ -169,6 +169,22 @@ def list_custom_tools(self) -> list[dict[str, Any]]: result = self._request("GET", "prompt-studio/") return result if isinstance(result, list) else result.get("results", []) + def get_custom_tool(self, tool_id: str) -> dict[str, Any]: + """Fetch a single prompt-studio project (full serializer). + + Returns ``fields = "__all__"`` per ``CustomToolSerializer`` — + notably includes ``output`` (the default DocumentManager id the + FE binds to ``selectedDoc`` on load). + """ + return self._request("GET", f"prompt-studio/{tool_id}/") + + def update_custom_tool( + self, tool_id: str, body: dict[str, Any] + ) -> dict[str, Any]: + """PATCH a prompt-studio project. Used to set ``output`` (the + default doc id) after the files phase populates DM rows.""" + return self._request("PATCH", f"prompt-studio/{tool_id}/", json=body) + def list_profiles(self, tool_id: str) -> list[dict[str, Any]]: """List ProfileManager rows for a tool. 
@@ -261,24 +277,25 @@ def list_prompt_documents(self, tool_id: str) -> list[dict[str, Any]]: ``to_representation`` filter). """ result = self._request( - "GET", "prompt-document/", params={"tool_id": tool_id} + "GET", "prompt-studio/prompt-document/", params={"tool_id": tool_id} ) return result if isinstance(result, list) else result.get("results", []) def download_prompt_file( - self, tool_id: str, file_name: str + self, tool_id: str, document_id: str ) -> dict[str, Any]: - """GET a Prompt Studio document by tool + filename. + """GET a Prompt Studio document by tool + DM row id. - Returns the backend's ``{"data": ..., "mime_type": ...}`` envelope - verbatim. PDFs come back as base64; text/csv as decoded utf-8; - Excel returns a placeholder string (not real bytes) — callers must - treat unsupported mime types as needing manual re-upload. + ``fetch_contents_ide`` resolves the filename internally from the + DocumentManager row, so the SDK passes the ``document_id`` it + already has from ``list_prompt_documents`` rather than reposting + the filename. Returns ``{"data": ..., "mime_type": ...}`` — + PDFs base64, text/csv utf-8, Excel placeholder. 
""" return self._request( "GET", f"prompt-studio/file/{tool_id}", - params={"file_name": file_name}, + params={"document_id": document_id}, ) def upload_prompt_file( @@ -297,7 +314,7 @@ def upload_prompt_file( """ files = {"file": (file_name, data, mime_type)} return self._request( - "POST", f"prompt-studio/file/{tool_id}/", files=files + "POST", f"prompt-studio/file/{tool_id}", files=files ) def export_custom_tool(self, tool_id: str, *, force: bool = True) -> Any: diff --git a/src/unstract/migration/phases/files.py b/src/unstract/migration/phases/files.py index b76ab28..868067c 100644 --- a/src/unstract/migration/phases/files.py +++ b/src/unstract/migration/phases/files.py @@ -114,7 +114,8 @@ def _migrate_tool( for doc in src_docs: file_name = doc.get("document_name") - if not file_name: + src_document_id = doc.get("document_id") + if not file_name or not src_document_id: continue if file_name in target_names: result.skipped += 1 @@ -131,7 +132,18 @@ def _migrate_tool( ) continue self._migrate_one_file( - src_tool_id, tgt_tool_id, tool_name, file_name, report, result + src_tool_id, + tgt_tool_id, + tool_name, + file_name, + src_document_id, + report, + result, + ) + + if not self.ctx.options.dry_run: + self._ensure_default_doc( + src_tool_id, tgt_tool_id, tool_name, src_docs ) def _migrate_one_file( @@ -140,12 +152,15 @@ def _migrate_one_file( tgt_tool_id: str, tool_name: str, file_name: str, + src_document_id: str, report: MigrationReport, result: PhaseResult, ) -> None: try: payload = self._with_retry( - lambda: self.ctx.source.download_prompt_file(src_tool_id, file_name), + lambda: self.ctx.source.download_prompt_file( + src_tool_id, src_document_id + ), op=f"download {tool_name}/{file_name}", ) except Exception as e: @@ -284,6 +299,101 @@ def _decode_payload( # not real bytes. Round-trip would corrupt the file. 
return None + def _ensure_default_doc( + self, + src_tool_id: str, + tgt_tool_id: str, + tool_name: str, + src_docs: list[dict[str, Any]], + ) -> None: + """Set target ``CustomTool.output`` so the FE auto-selects a doc. + + Mirror source's chosen doc by filename when possible; fall back + to the first available target doc. Skip if target already has + ``output`` set — never override an operator's later choice on + re-runs. + """ + try: + tgt_tool = self.ctx.target.get_custom_tool(tgt_tool_id) + except Exception as e: + logger.warning( + "files: skipping default-doc set for tool=%s — fetch tgt failed: %s", + tool_name, e, + ) + return + + if tgt_tool.get("output"): + logger.debug( + "files: target tool=%s already has default doc; leaving as-is", + tool_name, + ) + return + + try: + tgt_docs = self.ctx.target.list_prompt_documents(tgt_tool_id) + except Exception as e: + logger.warning( + "files: skipping default-doc set for tool=%s — list tgt docs failed: %s", + tool_name, e, + ) + return + if not tgt_docs: + return + + chosen_id = self._pick_default_doc_id( + src_tool_id, src_docs, tgt_docs, tool_name + ) + if not chosen_id: + return + + try: + self.ctx.target.update_custom_tool(tgt_tool_id, {"output": chosen_id}) + logger.info( + "files: set default doc tool=%s doc_id=%s", tool_name, chosen_id + ) + except Exception as e: + logger.warning( + "files: PATCH default doc failed tool=%s: %s", tool_name, e + ) + + def _pick_default_doc_id( + self, + src_tool_id: str, + src_docs: list[dict[str, Any]], + tgt_docs: list[dict[str, Any]], + tool_name: str, + ) -> str | None: + # Try mirroring the source's selection by filename. If source + # GET fails or source has no chosen doc, fall back to the first + # target doc so the FE doesn't render an empty selector. 
+ try: + src_tool = self.ctx.source.get_custom_tool(src_tool_id) + src_output = src_tool.get("output") + except Exception as e: + logger.debug( + "files: source CustomTool fetch failed for tool=%s (%s); " + "falling back to first target doc", + tool_name, e, + ) + src_output = None + + if src_output: + src_name = next( + (d.get("document_name") for d in src_docs + if d.get("document_id") == src_output), + None, + ) + if src_name: + matched = next( + (d.get("document_id") for d in tgt_docs + if d.get("document_name") == src_name), + None, + ) + if matched: + return matched + + return tgt_docs[0].get("document_id") + def _lookup_tool_name(self, tgt_tool_id: str) -> str | None: # CustomToolPhase doesn't record names; fetch lazily for log clarity. # One call per tool is cheap relative to the per-file traffic. diff --git a/src/unstract/migration/phases/tool_instance.py b/src/unstract/migration/phases/tool_instance.py index e4d3cc0..efe5a26 100644 --- a/src/unstract/migration/phases/tool_instance.py +++ b/src/unstract/migration/phases/tool_instance.py @@ -27,6 +27,28 @@ logger = logging.getLogger(__name__) +# Source backend's ToolInstanceSerializer.to_representation emits these +# sentinel strings when an adapter UUID/name in the stored metadata can +# no longer be resolved (deleted or renamed on source). Round-tripping +# them to target produces an AdapterNotFound on PATCH, so we detect and +# skip the metadata PATCH instead — the ToolInstance row exists with the +# backend's safe defaults and the operator can re-bind in the UI. +_BROKEN_ADAPTER_SENTINELS: tuple[str, ...] 
= ( + "NOT FOUND", + "[DELETED ADAPTER", + "[NEEDS UPDATE]", +) + + +def _broken_adapter_keys(metadata: dict[str, Any]) -> list[str]: + broken: list[str] = [] + for key, value in metadata.items(): + if isinstance(value, str) and any( + s in value for s in _BROKEN_ADAPTER_SENTINELS + ): + broken.append(f"{key}={value!r}") + return broken + class ToolInstancePhase(Phase): name = "tool_instance" @@ -120,14 +142,28 @@ def _migrate_workflow_tools( # PATCH the metadata regardless of created/adopted — keeps tool config # aligned with source on every run. src_metadata = src_ti.get("metadata") or {} - try: - self.ctx.target.update_tool_instance_metadata(tgt_ti["id"], src_metadata) - except Exception as e: - logger.exception( - "Failed to PATCH tool_instance %s metadata: %s", tgt_ti["id"], e + broken = _broken_adapter_keys(src_metadata) + if broken: + logger.warning( + "skipping metadata PATCH for tool_instance src=%s tgt=%s — " + "source metadata carries broken adapter refs %s; " + "row exists with backend defaults, re-bind in UI", + src_ti_id, tgt_ti["id"], broken, ) - result.failed += 1 - result.errors.append(f"patch metadata {tgt_ti['id']}: {e}") - return + result.errors.append( + f"stale adapter refs on src tool_instance {src_ti_id}: {broken}" + ) + else: + try: + self.ctx.target.update_tool_instance_metadata( + tgt_ti["id"], src_metadata + ) + except Exception as e: + logger.exception( + "Failed to PATCH tool_instance %s metadata: %s", tgt_ti["id"], e + ) + result.failed += 1 + result.errors.append(f"patch metadata {tgt_ti['id']}: {e}") + return self.ctx.remap.record("tool_instance", src_ti_id, tgt_ti["id"]) diff --git a/src/unstract/migration/report.py b/src/unstract/migration/report.py index a663ffa..fa06568 100644 --- a/src/unstract/migration/report.py +++ b/src/unstract/migration/report.py @@ -7,9 +7,12 @@ from __future__ import annotations +import logging from dataclasses import dataclass, field from typing import Any +logger = logging.getLogger(__name__) + 
@dataclass class PhaseResult: @@ -65,15 +68,7 @@ def render(self) -> str: if self.skipped_phases: console.print(f"[dim]Skipped phases:[/dim] {', '.join(self.skipped_phases)}") self._render_files_sections(console) - if self.remap_snapshot: - remap = Table(title="Source -> Target UUID Map") - remap.add_column("Entity") - remap.add_column("Source UUID") - remap.add_column("Target UUID") - for entity, mapping in self.remap_snapshot.items(): - for src, tgt in mapping.items(): - remap.add_row(entity, src, tgt) - console.print(remap) + self._render_remap_summary(console_print=console.print) if self.aborted: console.print(f"[red]ABORTED:[/red] {self.abort_reason}") return buf.getvalue() @@ -89,13 +84,7 @@ def _render_plain(self) -> str: if self.skipped_phases: lines.append(f"Skipped phases: {', '.join(self.skipped_phases)}") lines.extend(self._files_sections_plain()) - if self.remap_snapshot: - lines.append("") - lines.append("Source -> Target UUID Map") - lines.append("-" * 60) - for entity, mapping in self.remap_snapshot.items(): - for src, tgt in mapping.items(): - lines.append(f" {entity:<12} {src} -> {tgt}") + self._render_remap_summary(console_print=lines.append) if self.aborted: lines.append(f"ABORTED: {self.abort_reason}") return "\n".join(lines) @@ -124,6 +113,25 @@ def as_dict(self) -> dict[str, Any]: "failed_files": list(self.failed_files), } + def _render_remap_summary(self, console_print: Any) -> None: + """Summarise the remap snapshot. Full map is large and noisy, so + we only print per-entity counts here; the full mapping is emitted + at DEBUG and remains in ``as_dict()`` for programmatic consumers. 
+ """ + if not self.remap_snapshot: + return + counts = ", ".join( + f"{entity}={len(mapping)}" + for entity, mapping in self.remap_snapshot.items() + if mapping + ) + if counts: + console_print(f"Remap entries: {counts}") + if logger.isEnabledFor(logging.DEBUG): + for entity, mapping in self.remap_snapshot.items(): + for src, tgt in mapping.items(): + logger.debug("remap %s %s -> %s", entity, src, tgt) + def _render_files_sections(self, console: Any) -> None: if self.uploaded_files: console.print( diff --git a/tests/migration/test_files_phase.py b/tests/migration/test_files_phase.py index 85d3196..43744ae 100644 --- a/tests/migration/test_files_phase.py +++ b/tests/migration/test_files_phase.py @@ -72,7 +72,17 @@ def list_prompt_documents(self, tool_id: str) -> list[dict]: raise self.list_errors[tool_id] return [dict(d) for d in self._documents.get(tool_id, [])] - def download_prompt_file(self, tool_id: str, file_name: str) -> dict: + def download_prompt_file(self, tool_id: str, document_id: str) -> dict: + # Tests key payloads + error queues by (tool_id, file_name) for + # readability; resolve the filename from the documents list. 
+ file_name = next( + ( + d["document_name"] + for d in self._documents.get(tool_id, []) + if d.get("document_id") == document_id + ), + document_id, + ) self.download_calls.append((tool_id, file_name)) queue = self.download_errors.get((tool_id, file_name)) if queue: @@ -103,6 +113,16 @@ def upload_prompt_file( def list_custom_tools(self) -> list[dict]: return list(self._tools) + def get_custom_tool(self, tool_id: str) -> dict: + return dict(next((t for t in self._tools if t.get("tool_id") == tool_id), {})) + + def update_custom_tool(self, tool_id: str, body: dict) -> dict: + for t in self._tools: + if t.get("tool_id") == tool_id: + t.update(body) + return dict(t) + return {} + def _ctx(src: FakeClient, tgt: FakeClient, *, remap: RemapTable | None = None, **opts) -> MigrationContext: @@ -389,3 +409,76 @@ def test_text_mimes_round_trip_as_utf8(mime, raw): upload = tgt.uploaded[0] assert upload["data"] == raw.encode("utf-8") assert upload["mime_type"] == mime + + +def test_default_doc_mirrors_source_selection_by_filename(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf"), _doc("b.pdf")]}, + file_payloads={ + ("src-1", "a.pdf"): _pdf_payload(b"A"), + ("src-1", "b.pdf"): _pdf_payload(b"B"), + }, + # Source's selected doc is b.pdf (document_id="src-b.pdf"). + tools=[{"tool_id": "src-1", "tool_name": "demo", "output": "src-b.pdf"}], + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + FilesPhase(ctx).run(MigrationReport()) + + # Target's CustomTool.output now points at b.pdf's new target doc id. 
+ tgt_tool = next(t for t in tgt._tools if t["tool_id"] == "tgt-1") + output_id = tgt_tool["output"] + b_upload = next(d for d in tgt._documents["tgt-1"] if d["document_name"] == "b.pdf") + assert output_id == b_upload["document_id"] + + +def test_default_doc_falls_back_to_first_when_source_has_none(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf")]}, + file_payloads={("src-1", "a.pdf"): _pdf_payload(b"A")}, + # Source has no output set. + tools=[{"tool_id": "src-1", "tool_name": "demo"}], + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + FilesPhase(ctx).run(MigrationReport()) + + tgt_tool = next(t for t in tgt._tools if t["tool_id"] == "tgt-1") + a_upload = next(d for d in tgt._documents["tgt-1"] if d["document_name"] == "a.pdf") + assert tgt_tool["output"] == a_upload["document_id"] + + +def test_default_doc_preserves_existing_target_choice(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf")]}, + file_payloads={("src-1", "a.pdf"): _pdf_payload(b"A")}, + tools=[{"tool_id": "src-1", "tool_name": "demo", "output": "src-a.pdf"}], + ) + # Operator already picked a doc on target — re-run must not clobber. 
+ tgt = FakeClient( + endpoint=TGT_ENDPOINT, + tools=[ + {"tool_id": "tgt-1", "tool_name": "demo", "output": "operator-pick"} + ], + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + FilesPhase(ctx).run(MigrationReport()) + + tgt_tool = next(t for t in tgt._tools if t["tool_id"] == "tgt-1") + assert tgt_tool["output"] == "operator-pick" From bc4ded6aa5f90666385cafa321cf8ac04b7bd33a Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Sun, 24 May 2026 16:13:32 +0530 Subject: [PATCH 14/25] fix(migration): address greptile P1s on base + workflow_endpoint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit base.build_post_payload: Previous `value not in (None, "")` dropped booleans False and numeric 0 along with None/"" — DRF BooleanField False and numeric defaults were silently stripped from POST payloads. Switch to explicit identity + equality checks. New test_base_helpers.py guards this. workflow_endpoint._patch_endpoint: When source endpoint had a connector but its remap entry is missing (e.g. connector phase skipped a row), we previously PATCHed the target endpoint with connector_instance_id=None — silently detaching it. Now skip the PATCH, increment result.skipped, and append an error entry so the operator sees the broken link in the report. Existing test rewritten to assert the new skip-and-flag behaviour. 
Co-Authored-By: Claude Opus 4.7 --- src/unstract/migration/phases/base.py | 10 +++- .../migration/phases/workflow_endpoint.py | 15 ++++- tests/migration/test_base_helpers.py | 58 +++++++++++++++++++ .../migration/test_workflow_endpoint_phase.py | 14 +++-- 4 files changed, 89 insertions(+), 8 deletions(-) create mode 100644 tests/migration/test_base_helpers.py diff --git a/src/unstract/migration/phases/base.py b/src/unstract/migration/phases/base.py index 4b3e0ec..e8dd5a1 100644 --- a/src/unstract/migration/phases/base.py +++ b/src/unstract/migration/phases/base.py @@ -39,7 +39,15 @@ def build_post_payload( and rejects on required fields). """ keys = writable - SERVER_MANAGED - return {k: src[k] for k in keys if k in src and src[k] not in (None, "")} + # Strip only ``None`` and the empty string. ``False`` and ``0`` compare + # unequal to both (``0 in (None, "")`` is False), so they are kept. + # Explicit identity / equality checks make that intent obvious and preserve + # falsy-but-meaningful values like ``BooleanField`` False and numeric defaults. + return { + k: src[k] + for k in keys + if k in src and src[k] is not None and src[k] != "" + } class Phase(ABC): diff --git a/src/unstract/migration/phases/workflow_endpoint.py b/src/unstract/migration/phases/workflow_endpoint.py index cbdadd3..3e53995 100644 --- a/src/unstract/migration/phases/workflow_endpoint.py +++ b/src/unstract/migration/phases/workflow_endpoint.py @@ -120,10 +120,21 @@ def _patch_endpoint( if src_conn_id: tgt_conn_id = self.ctx.remap.resolve("connector", src_conn_id) if not tgt_conn_id: + # Source had a connector but it never made it through the + # connector phase (e.g. redacted secrets, skipped row). + # Patching the endpoint with connector=None would silently + # detach it on target; skip + flag so the operator notices.
logger.warning( - "no connector remap for %s on %s endpoint %s — leaving unset", - src_conn_id, etype, src_ep_id, + "skipping %s endpoint src=%s tgt=%s — source connector %s " + "has no target remap; would silently unset connector", + etype, src_ep_id, tgt_ep_id, src_conn_id, ) + result.skipped += 1 + result.errors.append( + f"unmapped connector on {etype} endpoint {src_ep_id}: " + f"src_connector={src_conn_id}" + ) + return payload: dict[str, Any] = { "connection_type": src_ep.get("connection_type") or "", diff --git a/tests/migration/test_base_helpers.py b/tests/migration/test_base_helpers.py new file mode 100644 index 0000000..ada8b87 --- /dev/null +++ b/tests/migration/test_base_helpers.py @@ -0,0 +1,58 @@ +"""Tests for ``unstract.migration.phases.base`` helpers.""" + +from __future__ import annotations + +from unstract.migration.phases.base import SERVER_MANAGED, build_post_payload + + +def test_preserves_false_and_zero_values(): + """Booleans set to False and numeric 0 are legitimate field values. + + Only None and "" may be stripped from the payload; falsy values such + as False, 0 and 0.0 must survive the filter in ``build_post_payload``. + Regression guard. 
+ """ + src = { + "is_active": False, + "retry_count": 0, + "rate_limit": 0.0, + "name": "demo", + } + writable = frozenset({"is_active", "retry_count", "rate_limit", "name"}) + + payload = build_post_payload(src, writable) + + assert payload == { + "is_active": False, + "retry_count": 0, + "rate_limit": 0.0, + "name": "demo", + } + + +def test_strips_none_and_empty_string_but_keeps_zero(): + src = {"a": None, "b": "", "c": 0, "d": False, "e": "kept"} + writable = frozenset({"a", "b", "c", "d", "e"}) + + payload = build_post_payload(src, writable) + + assert payload == {"c": 0, "d": False, "e": "kept"} + + +def test_drops_server_managed_keys_even_if_writable(): + src = {"id": "X", "name": "demo", "organization": "org", "created_by": "u"} + # All four are nominally writable but SERVER_MANAGED should win. + writable = frozenset(src.keys()) + + payload = build_post_payload(src, writable) + + assert payload == {"name": "demo"} + for key in SERVER_MANAGED & set(src.keys()): + assert key not in payload + + +def test_ignores_writable_keys_missing_from_src(): + src = {"present": 1} + writable = frozenset({"present", "absent"}) + + assert build_post_payload(src, writable) == {"present": 1} diff --git a/tests/migration/test_workflow_endpoint_phase.py b/tests/migration/test_workflow_endpoint_phase.py index 2f21f52..161ce92 100644 --- a/tests/migration/test_workflow_endpoint_phase.py +++ b/tests/migration/test_workflow_endpoint_phase.py @@ -155,7 +155,11 @@ def test_endpoint_without_source_connector_patches_with_null(): assert payload["configuration"] == {"foo": "bar"} -def test_unknown_connector_uuid_logs_but_does_not_fail(): +def test_unknown_connector_uuid_skips_endpoint_and_flags_error(): + """Source had a connector but its remap is missing — patching with + connector=None would silently detach the endpoint on target. Skip + the PATCH and record an operator-visible error entry instead. 
+ """ src = FakeClient() src.endpoints[SRC_WF] = [ _src_endpoint( @@ -171,10 +175,10 @@ def test_unknown_connector_uuid_logs_but_does_not_fail(): result = WorkflowEndpointPhase(ctx).run(MigrationReport()) - assert result.created == 1 - # No remap → connector_instance_id stays None instead of failing the PATCH. - _, payload = tgt.patch_calls[0] - assert payload["connector_instance_id"] is None + assert result.created == 0 + assert result.skipped == 1 + assert tgt.patch_calls == [] + assert any("unmapped connector" in e for e in result.errors) def test_missing_target_endpoint_fails_loudly(): From a756d13032edab286a9881d126b987e257f5bf65 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Sun, 24 May 2026 16:28:28 +0530 Subject: [PATCH 15/25] perf(migration): hoist target list_custom_tools out of per-tool loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit list_custom_tools() was called once per source tool — N target-API round-trips for the same invariant data. Fetch once before the loop and append locally on create so adoption lookups stay correct on re-runs. Also drop the value.get("id") fallback in _extract_adapter_name — returning a UUID where the caller expects a name made list_adapters silently miss with a confusing warning. Co-Authored-By: Claude Opus 4.7 --- src/unstract/migration/phases/custom_tool.py | 42 ++++++++++++++------ 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/src/unstract/migration/phases/custom_tool.py b/src/unstract/migration/phases/custom_tool.py index 5831b52..c7bec8c 100644 --- a/src/unstract/migration/phases/custom_tool.py +++ b/src/unstract/migration/phases/custom_tool.py @@ -45,14 +45,14 @@ def _extract_adapter_name(value: Any) -> str | None: - """ProfileManagerSerializer.to_representation renders adapter FKs as - flat strings holding the adapter NAME (not the UUID). Tolerate the - nested-dict shape too in case serializer behavior diverges. 
+ """Adapter FKs serialise as the adapter NAME on the wire; tolerate a + nested-dict shape too. Never fall back to the UUID — list_adapters + matches by name and would silently miss. """ if isinstance(value, str): return value or None if isinstance(value, dict): - return value.get("adapter_name") or value.get("name") or value.get("id") + return value.get("adapter_name") or value.get("name") return None @@ -70,11 +70,27 @@ def run(self, report: MigrationReport) -> PhaseResult: return result logger.info("Found %d custom tool(s) in source org", len(src_tools)) + # Fetch the target list once — name-based adoption lookup is + # done per source tool, but the underlying list is invariant + # across the loop barring our own creates (which we splice into + # ``target_tools`` after each create so re-runs stay idempotent). + try: + target_tools = self.ctx.target.list_custom_tools() + except Exception as e: + logger.exception("Failed to list target tools: %s", e) + result.failed += 1 + result.errors.append(f"list target tools: {e}") + return result for summary in src_tools: - self._migrate_one(summary, result) + self._migrate_one(summary, target_tools, result) return result - def _migrate_one(self, summary: dict[str, Any], result: PhaseResult) -> None: + def _migrate_one( + self, + summary: dict[str, Any], + target_tools: list[dict[str, Any]], + result: PhaseResult, + ) -> None: tool_name = summary["tool_name"] src_tool_id = summary["tool_id"] @@ -86,13 +102,6 @@ def _migrate_one(self, summary: dict[str, Any], result: PhaseResult) -> None: result.errors.append(f"export src tool {tool_name}: {e}") return - try: - target_tools = self.ctx.target.list_custom_tools() - except Exception as e: - logger.exception("Failed to list target tools: %s", e) - result.failed += 1 - result.errors.append(f"list target tools: {e}") - return match = next( (t for t in target_tools if t["tool_name"] == tool_name), None ) @@ -103,6 +112,13 @@ def _migrate_one(self, summary: dict[str, Any], result: 
PhaseResult) -> None: tgt_tool_id = self._create_fresh( export_data, src_tool_id, tool_name, result ) + # Keep the local cache in sync so a downstream source tool + # with the same name (uncommon but legal) adopts this new + # row instead of trying to re-create it. + if tgt_tool_id is not None: + target_tools.append( + {"tool_id": tgt_tool_id, "tool_name": tool_name} + ) if tgt_tool_id is None: return From 0811d31debd63a6f454c37dd05a43fa00fa43158 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M <117059509+chandrasekharan-zipstack@users.noreply.github.com> Date: Sun, 24 May 2026 17:37:40 +0530 Subject: [PATCH 16/25] Delete docs/internal/files-migration-plan.md Signed-off-by: Chandrasekharan M <117059509+chandrasekharan-zipstack@users.noreply.github.com> --- docs/internal/files-migration-plan.md | 355 -------------------------- 1 file changed, 355 deletions(-) delete mode 100644 docs/internal/files-migration-plan.md diff --git a/docs/internal/files-migration-plan.md b/docs/internal/files-migration-plan.md deleted file mode 100644 index 8590778..0000000 --- a/docs/internal/files-migration-plan.md +++ /dev/null @@ -1,355 +0,0 @@ -# Files-Migration Phase — Implementation Plan - -**Status:** draft, revised 2026-05-24 (post product-owner review). Not a KB doc. Authoritative runbook for the SDK implementor agent (`odm-sdk-impl`) to land the `files` phase end-to-end. - -**Constraints from product owner:** -- **No new backend Platform API endpoints.** Use what exists today. -- **No storage-backend direct copy** (per-provider CLI + credentials handoff is too painful to maintain across customer environments). -- **No signed-URL relay** (would depend on new BE endpoints). -- Backend changes limited to non-API refactors (streaming write helper) that benefit the FE simultaneously. - -**Net scope: Platform-API uploads only. Oversize files are reported, not aborted on — operator handles those via UI re-upload.** - ---- - -## 0. 
What this phase moves - -Prompt Studio document files only. The user-uploaded test corpus per `CustomTool`. Storage path on both ends: - -``` -{PERMANENT_REMOTE_STORAGE}/{REMOTE_PROMPT_STUDIO_FILE_PATH}/{org_id}/{user_id}/{tool_id}/ -``` - -Server-generated subdirs (`extract/`, `summarize/`, `converted/`) are **out of scope** — regenerated on target on first index/summarize/preview. Don't migrate them. - -`DocumentManager` rows (`prompt_studio_document_manager_v2.DocumentManager`) are the DB-side mirror of the files. **The Platform API upload endpoint creates the `DocumentManager` row as a side effect of the file upload itself** — see §10a. `CustomToolPhase` does not create them. - -## 1. Where the phase sits - -Dependency order: - -``` -… → custom_tool → files → workflow → tool_instance → workflow_endpoint → api_deployment → pipeline -``` - -Runs after `custom_tool` (which mints the `src_tool_id → tgt_tool_id` remap and creates the CustomTool + ProfileManager + prompts on target) and before `workflow` (workflows can reference indexed prompt outputs, but `workflow` creation itself doesn't need the bytes present). - -## 2. Strategy dispatch - -Selectable via `MigrationOptions.file_strategy`: - -| Value | Behavior | -|-------|----------| -| `"platform_api"` (default) | Uploads through existing Platform API endpoints, capped per file. Files over cap are reported, not aborted on. | -| `"skip"` | Operator re-uploads manually via UI on target; SDK does pure metadata migration. | - -**Sizing rationale (US prod data, 2026-05-24):** 18,890 orgs with files, p99=14 files/org, max external=529 (buildwithstudio). Moody's biggest single user-org=99 files. At ~3 sec per file roundtrip, Moody's worst case = ~5 min, buildwithstudio-class = ~25 min. Sequential through the Platform API is in budget. File sizes not in DB — sampling recommended but not blocking; cap is the safety valve. - -## 3. 
Default upload flow - -### Mechanism - -Use the existing endpoints exactly as they are: - -- **Download from source:** `GET /api/v1//prompt-studio/file//?file_name=` — returns `{"data": , "mime_type": "..."}` per `PromptStudioFileHelper.fetch_file_contents`. -- **Upload to target:** `POST /api/v1//prompt-studio/file//` with multipart body — see `upload_for_ide` signature. -- **List target files for idempotency:** `GET /api/v1//prompt-document/?tool_id=` — returns `[{document_id, document_name, tool}, ...]`; the `document_name` field is what the pre-check matches against. - -### Memory + size discipline - -- **Hard per-file cap.** Default `max_file_size = 25 * 1024 * 1024` (25 MB). Above → skip that file, log to `MigrationReport.oversize_files`, continue siblings. Operator can override via CLI flag (`--max-file-size`). -- **Cap rationale.** 10 MB is too tight (a 50-page scanned PDF can be 15-30 MB); 50 MB risks cloud-worker memory pressure given the base64+JSON-wrap overhead on download (~3x the file size in peak memory). 25 MB covers typical Prompt Studio test docs while staying within a safe worker memory budget. -- **No reliable pre-flight size check.** Existing endpoints don't expose size separately from full payload. Workaround: enforce cap at request time after download — the SDK only knows the size after it has the bytes in memory. Acceptable given the cap is small enough that a single oversize-download spike is bounded. -- **Concurrency = 1** per phase. Do not raise — single-file-at-a-time is what keeps the cloud worker hold bounded. -- **Retry** transient HTTP errors (5xx, ConnectionError, Timeout) with exponential backoff, max 3 attempts. -- **No body logging** in either direction. - -### Idempotency caching (important for re-run cost) - -Fetch the target tool's file list **once per tool**, not per-file. With 529-file corpora (buildwithstudio scale), per-entity name lookups would balloon to 529 list calls just for skip-checks. 
Pattern: - -```python -def migrate_tool_files(src_tool_id, tgt_tool_id): - tgt_filenames = set(target.list_tool_filenames(tgt_tool_id)) # 1 call - src_files = source.list_tool_filenames(src_tool_id) # 1 call - for fname in src_files: - if fname in tgt_filenames: - report.skipped_existing += 1 - continue - migrate_file(src_tool_id, tgt_tool_id, fname, tgt_filenames) -``` - -This keeps re-run cost at ~2 HTTP calls per tool regardless of file count. Moody's full re-run (10 user-orgs × ~10 tools × 2 calls) ≈ 200 quick HTTP calls ≈ 10-20 sec total for the files phase. - -### Per-file flow - -```python -def migrate_file(src_tool_id, tgt_tool_id, file_name): - # idempotency: skip if name already present on target - if file_name in target_filenames_for(tgt_tool_id): - report.skipped += 1 - return - - # download — full file in memory (existing endpoint constraint) - resp = source.platform_api.get( - f"/{src_org_slug}/prompt-studio/file/{src_tool_id}/", - params={"file_name": file_name}, - ) - payload = resp.json() - mime = payload["mime_type"] - data_field = payload["data"] - - if mime == "application/pdf": - raw = base64.b64decode(data_field) - elif mime in ("text/plain", "text/csv"): - raw = data_field.encode("utf-8") - elif mime.startswith("application/vnd.ms-excel") or mime.startswith("application/vnd.openxmlformats"): - raw = base64.b64decode(data_field) # verify against helper's actual branch - else: - report.warnings.append(f"{file_name}: unknown mime '{mime}', skipping") - return - - if len(raw) > options.max_file_size: - report.oversize_files.append({ - "tool_id": tgt_tool_id, - "tool_name": tgt_tool_name, - "file_name": file_name, - "size_bytes": len(raw), - "cap_bytes": options.max_file_size, - }) - return # oversize → operator handles via UI re-upload - - # upload as multipart - target.platform_api.post( - f"/{tgt_org_slug}/prompt-studio/file/{tgt_tool_id}/", - files={"file": (file_name, raw, mime)}, - ) - report.uploaded += 1 -``` - -Verify mime-branch coverage 
against the actual `fetch_file_contents` helper (lines 167-188 of `prompt_studio_file_helper.py`) and extend if new branches landed. - -### Cloud safety check - -Before phase starts, log: - -``` -Files phase about to run via Platform API: - source: target: - tools: cap: 25 MB/file - estimated cloud-worker hold: ~1 worker × -``` - -Operator can abort here if running against cloud during peak hours. - -### Idempotency - -Name-based, listed-target side. No hash check. - -## 4. Skip mode - -Pure metadata migration. Phase prints, per migrated tool: - -``` -files: skipped (--skip-files). Documents must be re-uploaded on target via the IDE. - Navigate to each migrated tool on the target deployment and re-upload via the file manager pane. - - - tool 'invoice_extractor' (tgt_id=abc-123): 12 files expected - sample.pdf, contract_q1.pdf, ... - - tool 'receipt_classifier' (tgt_id=def-456): 3 files expected - receipt1.pdf, receipt2.pdf, receipt3.pdf -``` - -**No "destination path" is needed** — the operator clicks "upload" in the target UI's file manager, picks the file from their local disk, and the backend's `upload_for_ide` constructs the storage path from the calling user's session and the target `tool_id`. The operator never touches storage paths. - -**What the operator needs in their hand to do skip-files re-uploads:** the source bytes themselves. Options: -- They already have local copies (they originally uploaded these from their own machine). -- They download from source deployment UI one file at a time before re-uploading to target. Painful at >10-20 files; this is why uploads via Platform API are the default. - -`MigrationReport.skipped_files` carries the full list `[{tool_id, tool_name, file_name, source_org_slug, source_tool_id}, ...]` so external tooling (or a future helper script) can drive a download-then-upload loop using the source deployment's credentials. - -## 4a. 
Oversize file handling - -When a file exceeds `max_file_size` during the default flow, the SDK does **not** abort. It: - -1. Logs the file under `MigrationReport.oversize_files` with `{tool_id, tool_name, file_name, size_bytes, cap_bytes}`. -2. Continues with sibling files in the same tool. -3. At end-of-phase, prints a "files requiring manual upload" section listing the oversize subset. - -A corpus where 95% of files fit under cap therefore gets 95% auto-migrated and 5% surfaced for manual UI re-upload, in one run. - -## 6. Backend buffering side-quest (separate commit, not blocking the phase) - -The product owner has approved buffering the upload side in the BE because it benefits the FE concurrent-upload path too. **Constraint:** no new endpoints, no behavior change. Just internal refactor. - -### Change - -In `backend/utils/file_storage/helpers/prompt_studio_file_helper.py:upload_for_ide`, replace: - -```python -fs_instance.write( - path=file_path, - mode="wb", - data=file_data if isinstance(file_data, bytes) else file_data.read(), -) -``` - -with chunked streaming when `file_data` is an UploadedFile: - -```python -if isinstance(file_data, bytes): - fs_instance.write(path=file_path, mode="wb", data=file_data) -else: - with fs_instance.open(file_path, mode="wb", block_size=8 * 1024 * 1024) as out: - for chunk in file_data.chunks(chunk_size=8 * 1024 * 1024): - out.write(chunk) -``` - -Apply identical change to `upload_converted_for_ide` (same shape, different path). - -### Regression risks to verify - -1. **`block_size` must be set explicitly** — fsspec defaults vary per provider (GCS, S3, MinIO, Azure). Without it the implementation may buffer to memory until much larger thresholds. -2. **Partial-write cleanup on failure.** Wrap the streaming write in `try`/`except`; on exception, call the underlying multipart-abort if available (`fs.cancel(...)` for GCS resumable; `abort_multipart_upload` for S3 — exposed via `fs.fs.abort_multipart_upload` on s3fs). 
Document if a provider doesn't support clean abort; aged-out incomplete multipart uploads cost storage money. -3. **MIME detection.** `fetch_file_contents` does `fs.mime_type(...)` which is already partial-read-capable but tests should cover edge file types (Office docs, oddly-headered PDFs). -4. **`isinstance(file_data, bytes)` branch preserved** — some callers pass raw bytes; that path stays single-shot. -5. **Test coverage.** Add streaming-specific tests under `unstract/sdk1/tests/file_storage/` mirroring existing single-shot tests. Confirm peak memory bounded via `tracemalloc` snapshot in test. -6. **No change to `fetch_contents_ide`.** Constraint says no new BE APIs and FE consumes the existing response shape (base64 in JSON). Leave it alone. - -### Out of scope for this commit - -- Download streaming — would change `fetch_contents_ide` contract; product owner ruled out. -- New raw-download endpoint — same constraint. -- Resumability on the upload side — would also require new endpoint surface. - -## 7. SDK package layout - -``` -src/unstract/migration/ - ├── phases/ - │ └── files.py # FilesPhase: upload via Platform API with oversize reporting, or skip - ├── file_transport/ - │ ├── __init__.py - │ └── platform_api.py # download (base64-decode) + upload (multipart) using existing endpoints - └── ... (existing phases unchanged) -``` - -No direct-storage-copy or signed-URL transport modules — those approaches are out of scope. - -## 8. `MigrationOptions` additions - -```python -@dataclass -class MigrationOptions: - # ... existing fields ... - file_strategy: Literal["platform_api", "skip"] = "platform_api" - max_file_size: int = 25 * 1024 * 1024 # 25 MB; oversize files are reported, not uploaded -``` - -CLI flags: - -``` ---file-strategy {platform_api,skip} # default: platform_api ---max-file-size 25MB # accepts human-readable sizes ---skip-files # alias for --file-strategy=skip -``` - -No `auto` dispatch needed — only one byte-moving strategy exists. - -## 9. 
Test plan - -| Test | Mode | What it proves | -|------|------|----------------| -| 5 MB PDF via Platform API | default | round-trip base64+multipart works against existing helpers | -| 30 MB PDF via Platform API | default | cap fires; entry appears in `oversize_files`; sibling files continue | -| Tool with 10 files (mix of small + 1 oversize) | default | 9 migrated, 1 reported for manual upload, run exits 0 | -| Re-run after success | default | name-based idempotency: target file present → skip | -| `--skip-files` | skip | pure metadata migration; report lists tools + filenames (no storage paths) | -| `MigrationReport.skipped_files` schema | skip | shape matches `[{tool_id, tool_name, file_name, source_org_slug, source_tool_id}]` | -| `MigrationReport.oversize_files` schema | default | shape matches `[{tool_id, tool_name, file_name, size_bytes, cap_bytes}]` | -| Streaming upload backend refactor: 100 MB PDF via UI | (BE) | peak RSS bounded to ~chunk_size; file is byte-identical | -| Streaming upload + network failure mid-stream | (BE) | partial multipart aborted; no orphan upload in bucket | -| Default mode concurrent with FE user uploading | default | both succeed; no worker starvation observable | -| Moody's-scale run (~99 files, all under cap) | default | completes in <10 min; no failures; report counts match input | - -## 10. Acceptance - -- Adapter, connector, tag, custom_tool, prompts, profile_managers, prompt_registry phases unchanged. -- New `files` phase wired into `migrate()` orchestrator after `custom_tool`. -- Local smoke: fresh target org, run default mode against a tool with 3 small PDFs. Re-run: 3 skips, 0 failures. -- Local smoke: same scenario with one 30 MB PDF added → 3 uploaded, 1 in `oversize_files`, exit 0. -- Local smoke: `--skip-files` against same setup → 0 uploaded, all 4 listed in `skipped_files` with tool + filename. -- BE refactor commit (separate from SDK commit) lands on `feat/org-migration-platform-api-gaps`. 
SDK commit lands on `feat/org-migration`. -- `MigrationReport` exposes: `uploaded_files`, `skipped_files`, `oversize_files`, `failed_files` (each a typed list with tool_id + tool_name + file_name minimum). - -## 10a. Idempotency model — pre-check is load-bearing for correctness - -`upload_for_ide` (`views.py:1009`) is **not idempotent**, and in fact actively errors on retries. Order of operations: - -1. `PromptStudioFileHelper.upload_for_ide(...)` — write file to storage. Overwrites on collision (storage-idempotent). -2. `PromptStudioDocumentHelper.create(tool_id, document_name)` — unconditional `DocumentManager.objects.create`. - -The `DocumentManager` model carries `UniqueConstraint(fields=["document_name", "tool"])`. So calling `upload_for_ide` twice for the same `(tool, filename)` overwrites the file in storage, then raises `IntegrityError` on the second create — propagates as 500. **The SDK MUST pre-check; otherwise re-runs error out.** - -For files specifically, the pattern is: - -```python -def migrate_tool_files(src_tool_id, tgt_tool_id): - tgt_filenames = set(target.list_dm_rows(tgt_tool_id)) - for src_fname in source.list_dm_rows(src_tool_id): - if src_fname in tgt_filenames: - continue - upload_file(src_tool_id, tgt_tool_id, src_fname) -``` - -This guarantees the SDK never invokes the non-idempotent upload twice. The duplicate-DM-row outcome (well, the 500 — see above) only materializes if **something other than the SDK** invokes the upload endpoint with a name the SDK has already migrated. - -### Important difference from earlier draft of this plan - -This plan previously claimed `CustomToolPhase` creates `DocumentManager` rows on target. **It does not.** `import_project` only creates `CustomTool` + `ProfileManager` + prompts; DM rows are only created by `upload_for_ide`. Therefore: - -- Before the files phase runs on a fresh target tool, the DM list is empty. Pre-check returns ∅. SDK uploads all files. 
✅ -- On re-run, pre-check returns the SDK's own previously-created rows. SDK skips them. ✅ -- For an oversize / skipped file the operator re-uploads via UI: no SDK DM row exists for that filename → UI creates a single row, no duplicate. ✅ - -The duplicate-DM-row problem from prior plan revisions is **not reachable** through SDK flows. It only manifests if a user uploads via UI mid-migration (between the SDK's list call and its upload call for that same filename) — mitigation is to run migrations in low-activity windows. - -### Crash semantics - -The file-first-then-DM order is fortunate: - -| Failure point | Target state | Re-run outcome | -|---------------|--------------|----------------| -| File write fails | No file, no DM row | Pre-check sees no DM → retry → clean | -| File written, DM create fails | File on disk, no DM row | Pre-check sees no DM → retry → file overwritten, DM created. **Self-healing** | -| Both succeed, SDK dies | File + DM both present | Pre-check sees DM → skip. Correct | -| Both succeed, network blip on response | File + DM both present, SDK saw error | Same. SDK reconciles via pre-check, not via its own ack | - -The opposite order (DM first, file second) would leave a "ghost DM row pointing at nothing" state that the pre-check couldn't distinguish from a real one. We get lucky here. - -### Operator UI re-upload caveat - -When the operator manually re-uploads a file via the target UI to fill in something the SDK skipped or marked oversize, the UI hits the same non-idempotent endpoint. Because the SDK never created a DM row for that filename in those skipped cases, the UI's create succeeds cleanly — no duplicate. - -The only failure pattern: operator re-uploads a file the SDK already migrated (i.e. a file already on disk and in the DM table). The unique constraint will reject this with a 500. UI surfaces an error; the operator has to delete the existing row first via UI, then re-upload. 
- -### Optional backend cleanup (out of scope) - -`PromptStudioDocumentHelper.create` could be changed to `get_or_create(tool, document_name)` — eliminates the 500 on re-upload, benefits UI behavior too. **Not required by the SDK** (the pre-check makes the SDK safe regardless) and not blocking; queue as a follow-up if UI ergonomics complaints arrive. - -Document explicitly in the README's "What if files aren't on disk?" and "Files phase specifics" sections. - -## 11. What to push back on if encountered - -- Any request to add a new BE endpoint → confirm with product owner first; the constraint is "no new APIs". -- Any request to change `fetch_contents_ide` shape → same; FE consumes the base64+JSON envelope today. -- Any request to raise `max_file_size` above 50 MB → confirm cloud-worker RAM budget first; default 25 MB exists because base64 round-trip × 2 already runs ~3x file size in peak worker memory. -- Files larger than 25 MB common in the customer's corpus → that's expected to be a tail; operator handles via UI re-upload. If the tail is the majority, escalate — may need to revisit the no-new-endpoints constraint for a streaming download endpoint. -- Any request to copy files directly between storage backends or via signed URLs → ruled out: per-provider CLI maintenance burden, and signed URLs would need new BE endpoints. - -## 12. 
References - -- `backend/utils/file_storage/helpers/prompt_studio_file_helper.py` — `upload_for_ide`, `fetch_file_contents` (note: `fetch_contents_ide` returns base64-wrapped raw bytes for PDF, not LLMW-extracted text — extracted text lives under `extract/` subdir, written by a different code path) -- `backend/prompt_studio/prompt_studio_core_v2/urls.py` — `prompt_studio_file` route -- `backend/prompt_studio/prompt_studio_document_manager_v2/` — `DocumentManager` model -- Branch: backend changes → `feat/org-migration-platform-api-gaps`; SDK changes → `feat/org-migration` From b73556fd6b1ac719714a45547f67da2c1b87191e Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Sun, 24 May 2026 17:58:32 +0530 Subject: [PATCH 17/25] docs(migration): rewrite README for end users --- src/unstract/migration/README.md | 179 +++++++++---------------------- 1 file changed, 53 insertions(+), 126 deletions(-) diff --git a/src/unstract/migration/README.md b/src/unstract/migration/README.md index 5d892e0..b0f03e3 100644 --- a/src/unstract/migration/README.md +++ b/src/unstract/migration/README.md @@ -1,6 +1,8 @@ -# `unstract.migration` — Org-to-Org Data Migration +# Org-to-Org Migration -SDK subpackage that lifts an Unstract organization's configured resources from one deployment into another using existing Platform API endpoints. Adapters, connectors, custom tools, prompts, profiles, workflows, tool instances, workflow endpoints, tags, API deployments, pipelines, and Prompt Studio document files. +Move an Unstract organization's resources from one deployment to another over the Platform API. + +What gets carried over: adapters, connectors, custom tools, prompts, profiles, workflows, tool instances, workflow endpoints, tags, API deployments, pipelines, and Prompt Studio document files. ## Quickstart @@ -8,23 +10,24 @@ SDK subpackage that lifts an Unstract organization's configured resources from o UNSTRACT_SRC_PLATFORM_KEY=src_pk_... \ UNSTRACT_TGT_PLATFORM_KEY=tgt_pk_... 
\ uv run python -m unstract.migration migrate \ - --source-url https://us.unstract.com \ + --source-url https://source.example.com \ --source-org my-source-org \ - --target-url https://us.unstract.com \ + --target-url https://target.example.com \ --target-org my-target-org ``` -Both keys must be **org admin Platform API keys**. Run from a trusted machine. +You need an **org admin Platform API key** for both ends. -## How it works +> [!WARNING] +> Both keys grant full read on source and full write on target. Run from a trusted machine and **rotate both keys after the migration completes**. -The SDK orchestrates Platform API calls in a strict dependency order. Each phase migrates one resource type and feeds a remap table (`source_uuid → target_uuid`) that later phases consume to rewrite embedded references before POST. +## How it works -Phase order: +The tool walks resources in dependency order. Each phase migrates one type and remembers the new IDs so later phases can rewrite references before posting. ``` 1. adapter 7. republish_tool -2. connector 8. files (Prompt Studio document corpus) +2. connector 8. files (Prompt Studio documents) 3. tag 9. workflow 4. custom_tool 10. tool_instance 5. profile_manager 11. workflow_endpoint @@ -32,114 +35,54 @@ Phase order: 13. pipeline ``` -Phases 4–7 are composite (run together under `CustomToolPhase`). - -## Failure semantics — important - -### DB writes are committed per-resource, not per-phase - -Each POST is a separate Django request and a separate transaction on the target. There is **no all-or-nothing transaction wrapping a phase**. Consequences: - -- If a phase fails on the Nth entity, entities 1..N-1 are present on target. Entity N is rolled back (its own transaction). Entities N+1..M are never attempted. 
-- Side-effects that ride on POSTs (API keys auto-minted for API deployments and pipelines, PeriodicTask rows for scheduled pipelines, `DocumentManager` rows for tool documents) are persisted alongside their parent — same per-resource atomicity. -- The in-memory `RemapTable` is process state and is lost on crash. Re-run rebuilds it via Layer 2 idempotency. - -### Re-runs are idempotent and cheap - -**None of the Platform API write endpoints are naturally idempotent** — POSTing the same adapter / connector / workflow / file twice produces two target rows. The SDK works around this with a uniform pattern: **pre-check the target by name before POSTing.** - -Every phase: -- Lists target by name filter (or by parent tool, for files) — one call per phase. -- For each source entity: if already present on target, record `src_uuid → tgt_uuid` in the in-memory remap and skip the POST. If missing, per-id GET on source for the full payload, remap UUIDs, POST. - -The endpoint stays non-idempotent; the SDK guarantees idempotency by **not invoking the endpoint twice**. - -On a clean re-run after a fully-successful migration, no POSTs fire. Cost reduces to one list call per phase per tool. Typical re-run time: 1–2 minutes for a moderate corpus vs. 7–10 minutes for the first run. - -On a re-run after a partial-failure crash, completed phases skip-everything; the crashed phase resumes from the first missing entity. - -#### Files phase specifics - -The upload endpoint (`upload_for_ide`) writes the file to storage **first**, then creates the `DocumentManager` row. Both are unconditional — the endpoint has no upsert. Two consequences: +## Re-runs are safe -- Partial-failure between "file written" and "DM row created" leaves a file with no DM row. The SDK's pre-check looks at DM rows (filenames); seeing no row, it retries the upload, the file is overwritten (storage-idempotent), and the DM row is created. 
**Self-healing on re-run.** -- Once both succeed, the SDK's pre-check on the next run sees the DM row and skips the upload call entirely. No duplicate DM row. +Stop the script mid-run, fix what broke, run the same command again — it picks up where it left off. The tool checks the target by name before creating anything; resources that already exist are reused. -The one realistic case where the SDK's pre-check can be defeated is **concurrent UI upload mid-migration**: a user uploading the same filename through the IDE after the SDK already listed target filenames. The SDK then uploads anyway, creating a duplicate DM row. **Mitigation: run migrations in low-activity windows.** +A clean re-run after a successful migration does no writes and finishes in 1–2 minutes (a first run on a moderate corpus takes 7–10). -### Files phase is the exception worth knowing about +There is no resume flag and no state file. The target *is* the state — if you delete a resource on the target between runs, the next run recreates it. -Files are uploaded per-file, one at a time. Each upload is its own request. Failure semantics match the metadata phases (per-file commit, no all-or-nothing). But two extra wrinkles: +## If something fails partway -- **Oversize files** (above `--max-file-size`, default 25 MB) are not uploaded; they are recorded in `MigrationReport.oversize_files` and listed at end-of-phase for manual UI re-upload. The run does not abort. -- **If the files phase fails or is skipped**, the target has `CustomTool` + `DocumentManager` rows but no actual files in storage. The platform stays usable globally; per-file operations (preview, index, prompt-run) on missing files error cleanly. Users can re-upload missing files via the target UI's file manager. The platform doesn't crash. +Each resource is its own request and its own transaction. There is no all-or-nothing rollback for a phase. -See [What if files aren't on disk?](#what-if-files-arent-on-disk) for details. +1. 
Read the printed `MigrationReport` — it lists completed phases and the entity that failed. +2. Fix the underlying issue. +3. Re-run the same command. -### How to recover from a mid-failure crash +> [!NOTE] +> API deployments and pipelines get a **new API key minted on the target**. Downstream consumers must be updated with the new key. -1. Read the printed `MigrationReport` — completed phases + the entity that failed. -2. Fix the underlying issue (network, permissions, oversize payload, etc.). -3. Re-run the same command. The SDK picks up where it left off. +> [!NOTE] +> Pipelines are created **paused** on the target so scheduled runs don't fire during cut-over. Unpause them once you're ready. Override with `--no-pipelines-paused`. -There is no `--resume-from` flag and no state file. The target *is* the state. +## Files -## What gets migrated - -| Resource | Notes | -|----------|-------| -| Adapters | Including decrypted `adapter_metadata` (carries secrets verbatim — same surface the FE already consumes) | -| Connectors | Same secrets posture | -| Tags | Per-org | -| Custom tools | + nested: profile_managers, prompts, document_manager rows | -| Prompt registry | Re-published on target via `update_or_create` (no manual carry) | -| Files | Prompt Studio document corpus per tool — see [Files phase](#files-phase) | -| Workflows | Workflow_name remapped | -| Tool instances | v1 assumes ≤1 per workflow | -| Workflow endpoints | Connector references remapped | -| API deployments | New API key minted on target (consumer keys regenerate; document for downstream consumers) | -| Pipelines | New API key + PeriodicTask auto-minted server-side. Default state: paused (`active=false`) — SDK PATCHes immediately after POST to avoid cron firing on a half-cut-over org | - -## Files phase - -The only resource type with bytes-on-disk that migrate. 
Storage path on both ends: - -``` -{PERMANENT_REMOTE_STORAGE}/{REMOTE_PROMPT_STUDIO_FILE_PATH}/{org_id}/{user_id}/{tool_id}/ -``` - -Server-generated subdirs (`extract/`, `summarize/`, `converted/`) are **out of scope** — regenerated on target on first index/summarize/preview. - -### Strategy - -Two modes; default is `platform_api`. +The Prompt Studio document corpus is the only thing with actual bytes on disk. Default strategy downloads each file from source and uploads to target, one at a time, capped at 25 MB per file by default. | `--file-strategy` | Behavior | |-------------------|----------| -| `platform_api` (default) | Download each file via existing `fetch_contents_ide` endpoint, upload via `upload_for_ide`. Cap per file = `--max-file-size` (default 25 MB). Files over cap are reported for manual re-upload, not aborted. Concurrency = 1. | -| `skip` | No bytes touched. `DocumentManager` rows present on target (from `CustomToolPhase`), files missing on disk. Report lists every expected filename for manual UI re-upload. Equivalent: `--skip-files`. | - -### What if files aren't on disk? +| `platform_api` (default) | Transfer each file via the Platform API. Files over `--max-file-size` are skipped and listed at the end for manual re-upload. | +| `skip` | Don't transfer any files. Document records are still created on the target. Equivalent to `--skip-files`. | -After a `skip`, after oversize-file reporting in `platform_api` mode, or after a mid-failure crash before bytes were transferred, the target has `DocumentManager` rows that reference files not present in storage. **The platform stays usable globally.** Specifically: +> [!WARNING] +> If you run migrations while users are actively uploading to the same source org, you can end up with duplicate file records on the target. **Run migrations in low-activity windows.** -- Tool/workflow/deployment/pipeline listing and navigation: works. -- Opening any CustomTool in Prompt Studio: works. 
-- Per-file preview pane: errors with a `FileNotFoundError`-derived 500 (no explicit handler in the view today). UI shows an error. -- Index document / run prompt against missing file: errors cleanly (explicit handler). -- Re-upload via UI: works — restores the file. **Caveat:** the upload endpoint is not idempotent at the DB layer; it creates a new `DocumentManager` row unconditionally. If migration already created a DM row for that filename (it does via `CustomToolPhase`), the UI re-upload produces a second DM row pointing at the same (overwritten) file. UI will list the file twice. Delete the stale row via UI first, then re-upload, to avoid duplicates. The SDK itself avoids this trap by pre-checking DM rows before any upload call; the duplicate is purely a UI-side re-upload artifact. -- All other tools/workflows that have their files: unaffected. +> [!NOTE] +> If a file is missing on disk (skipped, oversize, or a mid-run crash), the platform stays usable. Only operations that touch that specific file (preview, index, prompt run) will error. Re-upload missing files through the UI. -So a partial files migration leaves users able to use the platform broadly; only the specific missing files surface errors when touched. +## What you'll see in the report -## Constraints and trade-offs +`MigrationReport` prints at the end with: -- **No new backend API endpoints** — files phase uses what exists today. The download path eats a ~33% base64 inflation and one-shot full-file memory on both ends. That's why the size cap is conservative. -- **Storage-backend direct copy (`gsutil rsync` / `aws s3 sync`) not supported** — per-provider CLI maintenance burden is too high. -- **No state file** — idempotency relies on target being queryable by name. If you delete a target resource between runs, the SDK recreates it on the next run. -- **No UUID preservation** — every target resource gets a freshly minted UUID. Embedded references are remapped via the in-memory `RemapTable`. 
+- Per-phase counts: `created`, `adopted` (already existed), `failed` +- `oversize_files` — files skipped because they exceeded the cap +- `skipped_files` — files not transferred under `--file-strategy=skip` +- `failed_files` — files the upload itself failed on +- A source-to-target UUID map for every migrated resource -## Configuration reference +## CLI reference ### Environment @@ -148,41 +91,25 @@ So a partial files migration leaves users able to use the platform broadly; only | `UNSTRACT_SRC_PLATFORM_KEY` | yes | Source org admin Platform API key | | `UNSTRACT_TGT_PLATFORM_KEY` | yes | Target org admin Platform API key | -### CLI flags +### Flags | Flag | Default | Purpose | |------|---------|---------| | `--source-url` / `--target-url` | — | Base URLs of both deployments | | `--source-org` / `--target-org` | — | Org slugs | -| `--api-prefix` | `api/v1` | URL prefix; varies on cloud | -| `--include` / `--exclude` | all / none | Phase filter (comma-separated phase names) | -| `--dry-run` | off | List actions, don't POST | -| `--on-name-conflict` | `adopt` | `adopt` (skip existing) or `abort` | +| `--api-prefix` | `api/v1` | URL prefix for the Platform API | +| `--include` / `--exclude` | all / none | Comma-separated phase names | +| `--dry-run` | off | List actions without writing | +| `--on-name-conflict` | `adopt` | `adopt` reuses existing target resources; `abort` stops on conflict | | `--file-strategy` | `platform_api` | `platform_api` or `skip` | -| `--max-file-size` | `25MB` | Per-file cap for files phase | +| `--max-file-size` | `25MB` | Per-file cap | | `--skip-files` | off | Alias for `--file-strategy=skip` | -| `--pipelines-paused` | on | Toggle the post-POST PATCH that pauses pipelines on target | +| `--pipelines-paused` | on | Create pipelines paused on target | | `--verbose` | off | Per-entity log lines | -## Report shape - -`MigrationReport` exposes: - -- `created` / `adopted` / `failed` counts per phase -- `oversize_files: list[{tool_id, 
tool_name, file_name, size_bytes, cap_bytes}]` -- `skipped_files: list[{tool_id, tool_name, file_name, source_org_slug, source_tool_id}]` -- `failed_files: list[{tool_id, tool_name, file_name, error}]` -- `remap_snapshot: dict[entity_type, dict[src_uuid, tgt_uuid]]` -- A pretty-printed source-to-target UUID map at end (rich-formatted; plain-text fallback) - -## Logging hygiene - -- Secret values (adapter/connector metadata) are not logged. -- File request/response bodies are not logged. -- Per-entity log lines format: `src=<source_uuid> -> tgt=<target_uuid>` plus entity name + type. -- Rotate both Platform API keys after the migration completes. - -## Further reading +## Things to keep in mind -- KB: `~/Documents/Obsidian Vault/zipstuff/org-data-migration/` (start with `INDEX.md`) -- Implementation plan for the files phase: `docs/internal/files-migration-plan.md` +- **Adapter and connector secrets are carried verbatim.** They never appear in logs, but they do travel over the wire to the target — both deployments must be ones you trust. +- **UUIDs are not preserved.** Every target resource gets a fresh UUID. References between resources are rewritten automatically. +- **Direct storage-bucket copy isn't supported.** Files always go through the Platform API. +- **Run from a trusted machine.** Both API keys are loaded as environment variables. From d2cc32fdb238623b41149df61eaf7bddb57eca6d Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Sun, 24 May 2026 18:05:57 +0530 Subject: [PATCH 18/25] docs(migration): add sample report + cross-deployment note --- src/unstract/migration/README.md | 41 +++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/src/unstract/migration/README.md b/src/unstract/migration/README.md index b0f03e3..841676d 100644 --- a/src/unstract/migration/README.md +++ b/src/unstract/migration/README.md @@ -2,6 +2,8 @@ Move an Unstract organization's resources from one deployment to another over the Platform API.
+Source and target can be the same instance (org-to-org on one deployment) or two different instances (one URL to another). The only requirement is that both expose the Platform API and you hold an admin key on each. + What gets carried over: adapters, connectors, custom tools, prompts, profiles, workflows, tool instances, workflow endpoints, tags, API deployments, pipelines, and Prompt Studio document files. ## Quickstart @@ -74,13 +76,40 @@ The Prompt Studio document corpus is the only thing with actual bytes on disk. D ## What you'll see in the report -`MigrationReport` prints at the end with: +At the end of every run, a `MigrationReport` is printed. Per-phase counts up top, then any files that need follow-up, then a remap summary, then a status footer. + +Counts per phase: `Created` (new on target), `Adopted` (already existed, reused), `Skipped`, `Failed`. + +``` + Migration Report +┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓ +┃ Phase ┃ Created ┃ Adopted ┃ Skipped ┃ Failed ┃ +┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩ +│ adapter │ 6 │ 2 │ 0 │ 0 │ +│ connector │ 3 │ 0 │ 0 │ 0 │ +│ tag │ 4 │ 0 │ 0 │ 0 │ +│ custom_tool │ 12 │ 1 │ 0 │ 0 │ +│ files │ 87 │ 0 │ 3 │ 1 │ +│ workflow │ 5 │ 0 │ 0 │ 0 │ +│ tool_instance │ 5 │ 0 │ 0 │ 0 │ +│ workflow_endpoint │ 10 │ 0 │ 0 │ 0 │ +│ api_deployment │ 2 │ 0 │ 0 │ 0 │ +│ pipeline │ 1 │ 0 │ 0 │ 0 │ +├──────────────────────┼─────────┼─────────┼─────────┼────────┤ +│ TOTAL │ 135 │ 3 │ 3 │ 1 │ +└──────────────────────┴─────────┴─────────┴─────────┴────────┘ +Files uploaded: 87 +Oversize files (manual upload required): + - tool=invoice-extractor file=scan-2023-archive.pdf size=41.2MB cap=25.0MB +Failed files: + - tool=contracts file=draft.docx error=upload timed out +Remap entries: adapter=8, connector=3, tag=4, custom_tool=13, workflow=5, ... +Completed with 1 failure(s) — see WARNING/ERROR log lines above for details +``` + +A clean run ends with `Completed successfully`. 
A run aborted by `--on-name-conflict=abort` ends with `ABORTED: <reason>`. -- Per-phase counts: `created`, `adopted` (already existed), `failed` -- `oversize_files` — files skipped because they exceeded the cap -- `skipped_files` — files not transferred under `--file-strategy=skip` -- `failed_files` — files the upload itself failed on -- A source-to-target UUID map for every migrated resource +The same data is available programmatically via `report.as_dict()` — useful if you're wrapping the command in your own automation. ## CLI reference From d8e64909afc7a38b50ef267ee531524cb7bf05c8 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Sun, 24 May 2026 18:16:58 +0530 Subject: [PATCH 19/25] docs(migration): slim README, defer to public docs --- src/unstract/migration/README.md | 126 +++---------------------------- 1 file changed, 10 insertions(+), 116 deletions(-) diff --git a/src/unstract/migration/README.md b/src/unstract/migration/README.md index 841676d..9c2a69c 100644 --- a/src/unstract/migration/README.md +++ b/src/unstract/migration/README.md @@ -1,10 +1,11 @@ # Org-to-Org Migration -Move an Unstract organization's resources from one deployment to another over the Platform API. +Move an Unstract organization's configured resources from one deployment to another (or between two orgs on the same deployment). -Source and target can be the same instance (org-to-org on one deployment) or two different instances (one URL to another). The only requirement is that both expose the Platform API and you hold an admin key on each. +Carried over: adapters, connectors, custom tools, prompts, profiles, workflows, tool instances, workflow endpoints, tags, API deployments, pipelines, and Prompt Studio document files. -What gets carried over: adapters, connectors, custom tools, prompts, profiles, workflows, tool instances, workflow endpoints, tags, API deployments, pipelines, and Prompt Studio document files.
+> **Full documentation, behavior notes, CLI reference, and sample report:** +> https://docs.unstract.com/unstract/unstract_platform/api_documentation/versions/v1-org-migration/ ## Quickstart @@ -18,127 +19,20 @@ uv run python -m unstract.migration migrate \ --target-org my-target-org ``` -You need an **org admin Platform API key** for both ends. +Both keys must be **org admin Platform API keys**. > [!WARNING] -> Both keys grant full read on source and full write on target. Run from a trusted machine and **rotate both keys after the migration completes**. - -## How it works - -The tool walks resources in dependency order. Each phase migrates one type and remembers the new IDs so later phases can rewrite references before posting. - -``` -1. adapter 7. republish_tool -2. connector 8. files (Prompt Studio documents) -3. tag 9. workflow -4. custom_tool 10. tool_instance -5. profile_manager 11. workflow_endpoint -6. prompt 12. api_deployment - 13. pipeline -``` +> Both keys grant broad access. Run from a trusted machine and rotate both keys after the migration completes. ## Re-runs are safe -Stop the script mid-run, fix what broke, run the same command again — it picks up where it left off. The tool checks the target by name before creating anything; resources that already exist are reused. - -A clean re-run after a successful migration does no writes and finishes in 1–2 minutes (a first run on a moderate corpus takes 7–10). - -There is no resume flag and no state file. The target *is* the state — if you delete a resource on the target between runs, the next run recreates it. - -## If something fails partway - -Each resource is its own request and its own transaction. There is no all-or-nothing rollback for a phase. - -1. Read the printed `MigrationReport` — it lists completed phases and the entity that failed. -2. Fix the underlying issue. -3. Re-run the same command. - -> [!NOTE] -> API deployments and pipelines get a **new API key minted on the target**. 
Downstream consumers must be updated with the new key. - -> [!NOTE] -> Pipelines are created **paused** on the target so scheduled runs don't fire during cut-over. Unpause them once you're ready. Override with `--no-pipelines-paused`. +If a phase fails partway, fix the cause and re-run the same command. Resources already on the target are detected by name and reused. There is no `--resume-from` flag — the target is the state. ## Files -The Prompt Studio document corpus is the only thing with actual bytes on disk. Default strategy downloads each file from source and uploads to target, one at a time, capped at 25 MB per file by default. - -| `--file-strategy` | Behavior | -|-------------------|----------| -| `platform_api` (default) | Transfer each file via the Platform API. Files over `--max-file-size` are skipped and listed at the end for manual re-upload. | -| `skip` | Don't transfer any files. Document records are still created on the target. Equivalent to `--skip-files`. | +The Prompt Studio document corpus is the only resource type with bytes on disk. Default cap per file is 25 MB; oversize files are reported for manual re-upload. Use `--skip-files` to skip bytes entirely (document records are still created). > [!WARNING] -> If you run migrations while users are actively uploading to the same source org, you can end up with duplicate file records on the target. **Run migrations in low-activity windows.** - -> [!NOTE] -> If a file is missing on disk (skipped, oversize, or a mid-run crash), the platform stays usable. Only operations that touch that specific file (preview, index, prompt run) will error. Re-upload missing files through the UI. - -## What you'll see in the report - -At the end of every run, a `MigrationReport` is printed. Per-phase counts up top, then any files that need follow-up, then a remap summary, then a status footer. - -Counts per phase: `Created` (new on target), `Adopted` (already existed, reused), `Skipped`, `Failed`. 
- -``` - Migration Report -┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓ -┃ Phase ┃ Created ┃ Adopted ┃ Skipped ┃ Failed ┃ -┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩ -│ adapter │ 6 │ 2 │ 0 │ 0 │ -│ connector │ 3 │ 0 │ 0 │ 0 │ -│ tag │ 4 │ 0 │ 0 │ 0 │ -│ custom_tool │ 12 │ 1 │ 0 │ 0 │ -│ files │ 87 │ 0 │ 3 │ 1 │ -│ workflow │ 5 │ 0 │ 0 │ 0 │ -│ tool_instance │ 5 │ 0 │ 0 │ 0 │ -│ workflow_endpoint │ 10 │ 0 │ 0 │ 0 │ -│ api_deployment │ 2 │ 0 │ 0 │ 0 │ -│ pipeline │ 1 │ 0 │ 0 │ 0 │ -├──────────────────────┼─────────┼─────────┼─────────┼────────┤ -│ TOTAL │ 135 │ 3 │ 3 │ 1 │ -└──────────────────────┴─────────┴─────────┴─────────┴────────┘ -Files uploaded: 87 -Oversize files (manual upload required): - - tool=invoice-extractor file=scan-2023-archive.pdf size=41.2MB cap=25.0MB -Failed files: - - tool=contracts file=draft.docx error=upload timed out -Remap entries: adapter=8, connector=3, tag=4, custom_tool=13, workflow=5, ... -Completed with 1 failure(s) — see WARNING/ERROR log lines above for details -``` - -A clean run ends with `Completed successfully`. A run aborted by `--on-name-conflict=abort` ends with `ABORTED: `. - -The same data is available programmatically via `report.as_dict()` — useful if you're wrapping the command in your own automation. 
- -## CLI reference - -### Environment - -| Var | Required | Purpose | -|-----|----------|---------| -| `UNSTRACT_SRC_PLATFORM_KEY` | yes | Source org admin Platform API key | -| `UNSTRACT_TGT_PLATFORM_KEY` | yes | Target org admin Platform API key | - -### Flags - -| Flag | Default | Purpose | -|------|---------|---------| -| `--source-url` / `--target-url` | — | Base URLs of both deployments | -| `--source-org` / `--target-org` | — | Org slugs | -| `--api-prefix` | `api/v1` | URL prefix for the Platform API | -| `--include` / `--exclude` | all / none | Comma-separated phase names | -| `--dry-run` | off | List actions without writing | -| `--on-name-conflict` | `adopt` | `adopt` reuses existing target resources; `abort` stops on conflict | -| `--file-strategy` | `platform_api` | `platform_api` or `skip` | -| `--max-file-size` | `25MB` | Per-file cap | -| `--skip-files` | off | Alias for `--file-strategy=skip` | -| `--pipelines-paused` | on | Create pipelines paused on target | -| `--verbose` | off | Per-entity log lines | - -## Things to keep in mind +> Run migrations during low-activity windows. Concurrent uploads to the source org during a migration can create duplicate file records on the target. -- **Adapter and connector secrets are carried verbatim.** They never appear in logs, but they do travel over the wire to the target — both deployments must be ones you trust. -- **UUIDs are not preserved.** Every target resource gets a fresh UUID. References between resources are rewritten automatically. -- **Direct storage-bucket copy isn't supported.** Files always go through the Platform API. -- **Run from a trusted machine.** Both API keys are loaded as environment variables. +See the [public docs](https://docs.unstract.com/unstract/unstract_platform/api_documentation/versions/v1-org-migration/) for the full flag list, behavioral notes, and the format of the end-of-run report. 
From ca1b6dbc5fa139b2e9ff76ae4f01c51a3715d383 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Sun, 24 May 2026 18:49:46 +0530 Subject: [PATCH 20/25] refactor(clone): rename migration module/CLI to clone - src/unstract/migration -> src/unstract/clone, tests/migration -> tests/clone - MigrationReport/Context/Options/Error -> Clone* - CLI subcommand 'migrate' -> 'clone', script 'unstract-migrate' -> 'unstract-clone' - pyproject extra 'migration' -> 'clone' - README + docstrings + log lines updated; report title 'Migration Report' -> 'Clone Report' All 75 unit tests pass. --- pyproject.toml | 4 +- src/unstract/{migration => clone}/README.md | 16 ++-- src/unstract/{migration => clone}/__init__.py | 20 ++-- src/unstract/clone/__main__.py | 6 ++ src/unstract/{migration => clone}/cli.py | 30 +++--- src/unstract/{migration => clone}/client.py | 8 +- src/unstract/{migration => clone}/context.py | 20 ++-- .../{migration => clone}/exceptions.py | 12 +-- .../{migration => clone}/orchestrator.py | 37 ++++---- src/unstract/clone/phases/__init__.py | 34 +++++++ .../{migration => clone}/phases/adapter.py | 14 +-- .../phases/api_deployment.py | 16 ++-- .../{migration => clone}/phases/base.py | 10 +- .../{migration => clone}/phases/connector.py | 12 +-- .../phases/custom_tool.py | 16 ++-- .../{migration => clone}/phases/files.py | 28 +++--- .../{migration => clone}/phases/pipeline.py | 16 ++-- .../{migration => clone}/phases/tag.py | 14 +-- .../phases/tool_instance.py | 29 ++++-- .../{migration => clone}/phases/workflow.py | 16 ++-- .../phases/workflow_endpoint.py | 12 +-- src/unstract/{migration => clone}/report.py | 92 +++++++++++++++++-- src/unstract/{migration => clone}/walker.py | 2 +- src/unstract/migration/__main__.py | 6 -- src/unstract/migration/phases/__init__.py | 34 ------- tests/{migration => clone}/__init__.py | 0 .../test_adapter_phase.py | 24 ++--- .../test_api_deployment_phase.py | 30 +++--- .../{migration => clone}/test_base_helpers.py | 4 +- 
.../test_connector_phase.py | 26 +++--- .../test_custom_tool_phase.py | 30 +++--- .../{migration => clone}/test_files_phase.py | 48 +++++----- .../test_pipeline_phase.py | 34 +++---- .../{migration => clone}/test_remap_table.py | 2 +- tests/{migration => clone}/test_tag_phase.py | 24 ++--- .../test_tool_instance_phase.py | 43 +++++---- tests/{migration => clone}/test_walker.py | 4 +- .../test_workflow_endpoint_phase.py | 26 +++--- .../test_workflow_phase.py | 24 ++--- uv.lock | 8 +- 40 files changed, 466 insertions(+), 365 deletions(-) rename src/unstract/{migration => clone}/README.md (55%) rename src/unstract/{migration => clone}/__init__.py (52%) create mode 100644 src/unstract/clone/__main__.py rename src/unstract/{migration => clone}/cli.py (87%) rename src/unstract/{migration => clone}/client.py (98%) rename src/unstract/{migration => clone}/context.py (84%) rename src/unstract/{migration => clone}/exceptions.py (64%) rename src/unstract/{migration => clone}/orchestrator.py (71%) create mode 100644 src/unstract/clone/phases/__init__.py rename src/unstract/{migration => clone}/phases/adapter.py (88%) rename src/unstract/{migration => clone}/phases/api_deployment.py (91%) rename src/unstract/{migration => clone}/phases/base.py (86%) rename src/unstract/{migration => clone}/phases/connector.py (91%) rename src/unstract/{migration => clone}/phases/custom_tool.py (96%) rename src/unstract/{migration => clone}/phases/files.py (95%) rename src/unstract/{migration => clone}/phases/pipeline.py (91%) rename src/unstract/{migration => clone}/phases/tag.py (85%) rename src/unstract/{migration => clone}/phases/tool_instance.py (85%) rename src/unstract/{migration => clone}/phases/workflow.py (86%) rename src/unstract/{migration => clone}/phases/workflow_endpoint.py (94%) rename src/unstract/{migration => clone}/report.py (66%) rename src/unstract/{migration => clone}/walker.py (95%) delete mode 100644 src/unstract/migration/__main__.py delete mode 100644 
src/unstract/migration/phases/__init__.py rename tests/{migration => clone}/__init__.py (100%) rename tests/{migration => clone}/test_adapter_phase.py (90%) rename tests/{migration => clone}/test_api_deployment_phase.py (86%) rename tests/{migration => clone}/test_base_helpers.py (92%) rename tests/{migration => clone}/test_connector_phase.py (90%) rename tests/{migration => clone}/test_custom_tool_phase.py (93%) rename tests/{migration => clone}/test_files_phase.py (94%) rename tests/{migration => clone}/test_pipeline_phase.py (87%) rename tests/{migration => clone}/test_remap_table.py (95%) rename tests/{migration => clone}/test_tag_phase.py (85%) rename tests/{migration => clone}/test_tool_instance_phase.py (79%) rename tests/{migration => clone}/test_walker.py (93%) rename tests/{migration => clone}/test_workflow_endpoint_phase.py (90%) rename tests/{migration => clone}/test_workflow_phase.py (88%) diff --git a/pyproject.toml b/pyproject.toml index 0050c98..a208e3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,13 +26,13 @@ classifiers = [ ] [project.optional-dependencies] -migration = [ +clone = [ "click>=8.1", "rich>=13.7", ] [project.scripts] -unstract-migrate = "unstract.migration.cli:main" +unstract-clone = "unstract.clone.cli:main" [build-system] requires = ["hatchling"] diff --git a/src/unstract/migration/README.md b/src/unstract/clone/README.md similarity index 55% rename from src/unstract/migration/README.md rename to src/unstract/clone/README.md index 9c2a69c..b952df4 100644 --- a/src/unstract/migration/README.md +++ b/src/unstract/clone/README.md @@ -1,18 +1,18 @@ -# Org-to-Org Migration +# Cloning Organizations -Move an Unstract organization's configured resources from one deployment to another (or between two orgs on the same deployment). +Clone an Unstract organization's configured resources into another organization (same deployment or different). 
Useful for environment promotion (DEV → QA → PROD) and for spinning up a fresh org from a known-good baseline. -Carried over: adapters, connectors, custom tools, prompts, profiles, workflows, tool instances, workflow endpoints, tags, API deployments, pipelines, and Prompt Studio document files. +Cloned resources: adapters, connectors, custom tools, prompts, profiles, workflows, tool instances, workflow endpoints, tags, API deployments, pipelines, and Prompt Studio document files. The source org is left untouched. > **Full documentation, behavior notes, CLI reference, and sample report:** -> https://docs.unstract.com/unstract/unstract_platform/api_documentation/versions/v1-org-migration/ +> https://docs.unstract.com/unstract/unstract_platform/api_documentation/versions/v1-org-cloning/ ## Quickstart ```bash UNSTRACT_SRC_PLATFORM_KEY=src_pk_... \ UNSTRACT_TGT_PLATFORM_KEY=tgt_pk_... \ -uv run python -m unstract.migration migrate \ +uv run python -m unstract.clone clone \ --source-url https://source.example.com \ --source-org my-source-org \ --target-url https://target.example.com \ @@ -22,7 +22,7 @@ uv run python -m unstract.migration migrate \ Both keys must be **org admin Platform API keys**. > [!WARNING] -> Both keys grant broad access. Run from a trusted machine and rotate both keys after the migration completes. +> Both keys grant broad access. Run from a trusted machine and rotate both keys after the clone completes. ## Re-runs are safe @@ -33,6 +33,6 @@ If a phase fails partway, fix the cause and re-run the same command. Resources a The Prompt Studio document corpus is the only resource type with bytes on disk. Default cap per file is 25 MB; oversize files are reported for manual re-upload. Use `--skip-files` to skip bytes entirely (document records are still created). > [!WARNING] -> Run migrations during low-activity windows. Concurrent uploads to the source org during a migration can create duplicate file records on the target. 
+> Run clones during low-activity windows. Concurrent uploads to the source org during a clone can create duplicate file records on the target. -See the [public docs](https://docs.unstract.com/unstract/unstract_platform/api_documentation/versions/v1-org-migration/) for the full flag list, behavioral notes, and the format of the end-of-run report. +See the [public docs](https://docs.unstract.com/unstract/unstract_platform/api_documentation/versions/v1-org-cloning/) for the full flag list, behavioral notes, and the format of the end-of-run report. diff --git a/src/unstract/migration/__init__.py b/src/unstract/clone/__init__.py similarity index 52% rename from src/unstract/migration/__init__.py rename to src/unstract/clone/__init__.py index 728b5b6..c36300b 100644 --- a/src/unstract/migration/__init__.py +++ b/src/unstract/clone/__init__.py @@ -1,4 +1,4 @@ -"""Org-to-org data migration over the Platform API. +"""Cloning organizations over the Platform API. Migrates configured resources (adapters, connectors, custom tools, workflows, etc.) from one Unstract org to another using two admin-issued Platform API @@ -6,20 +6,20 @@ against existing target rows by natural key. 
""" -from unstract.migration.context import ( - MigrationContext, - MigrationOptions, +from unstract.clone.context import ( + CloneContext, + CloneOptions, OrgEndpoint, RemapTable, ) -from unstract.migration.orchestrator import migrate -from unstract.migration.report import MigrationReport +from unstract.clone.orchestrator import clone +from unstract.clone.report import CloneReport __all__ = [ - "MigrationContext", - "MigrationOptions", - "MigrationReport", + "CloneContext", + "CloneOptions", + "CloneReport", "OrgEndpoint", "RemapTable", - "migrate", + "clone", ] diff --git a/src/unstract/clone/__main__.py b/src/unstract/clone/__main__.py new file mode 100644 index 0000000..2eef45c --- /dev/null +++ b/src/unstract/clone/__main__.py @@ -0,0 +1,6 @@ +"""Entry point: ``python -m unstract.clone``.""" + +from unstract.clone.cli import main + +if __name__ == "__main__": + main() diff --git a/src/unstract/migration/cli.py b/src/unstract/clone/cli.py similarity index 87% rename from src/unstract/migration/cli.py rename to src/unstract/clone/cli.py index 0c8339f..5ce15b8 100644 --- a/src/unstract/migration/cli.py +++ b/src/unstract/clone/cli.py @@ -1,6 +1,6 @@ -"""Click-based CLI for ``unstract.migration``. +"""Click-based CLI for ``unstract.clone``. -Single ``migrate`` command. Platform keys can be passed via flags +Single ``clone`` command. Platform keys can be passed via flags (``--source-key`` / ``--target-key``) or env vars (``UNSTRACT_SRC_PLATFORM_KEY`` / ``UNSTRACT_TGT_PLATFORM_KEY``) — env vars are preferred so the key never lands in shell history. 
@@ -15,13 +15,13 @@ import click -from unstract.migration.context import ( +from unstract.clone.context import ( DEFAULT_MAX_FILE_SIZE, - MigrationOptions, + CloneOptions, OrgEndpoint, ) -from unstract.migration.exceptions import MigrationError -from unstract.migration.orchestrator import migrate as run_migrate +from unstract.clone.exceptions import CloneError +from unstract.clone.orchestrator import clone as run_clone _SIZE_UNITS: dict[str, int] = { "B": 1, @@ -65,11 +65,11 @@ def _split_csv(value: str | None) -> tuple[str, ...] | None: @click.group() def cli() -> None: - """Org-to-org data migration over the Platform API.""" + """Cloning organizations over the Platform API.""" -@cli.command("migrate") -@click.option("--source-url", required=True, help="Base URL of the source deployment (e.g. https://us.unstract.com)") +@cli.command("clone") +@click.option("--source-url", required=True, help="Base URL of the source deployment") @click.option("--source-org", required=True, help="Source organization_id (slug in the URL path)") @click.option( "--source-key", @@ -128,7 +128,7 @@ def cli() -> None: help="Alias for --file-strategy=skip.", ) @click.option("-v", "--verbose", is_flag=True, help="Debug logging") -def migrate_cmd( +def clone_cmd( source_url: str, source_org: str, source_key: str, @@ -145,7 +145,7 @@ def migrate_cmd( skip_files: bool, verbose: bool, ) -> None: - """Migrate configured resources from one org to another.""" + """Clone configured resources from one org to another.""" _configure_logging(verbose) effective_strategy = "skip" if skip_files else file_strategy @@ -154,7 +154,7 @@ def migrate_cmd( except click.BadParameter as e: raise click.UsageError(str(e)) from e - options = MigrationOptions( + options = CloneOptions( dry_run=dry_run, include=_split_csv(include), exclude=_split_csv(exclude) or (), @@ -178,9 +178,9 @@ def migrate_cmd( ) try: - report = run_migrate(source, target, options) - except MigrationError as e: - click.echo(f"Migration failed: 
{e}", err=True) + report = run_clone(source, target, options) + except CloneError as e: + click.echo(f"Clone failed: {e}", err=True) sys.exit(2) click.echo(report.render()) diff --git a/src/unstract/migration/client.py b/src/unstract/clone/client.py similarity index 98% rename from src/unstract/migration/client.py rename to src/unstract/clone/client.py index 9f17c96..097394a 100644 --- a/src/unstract/migration/client.py +++ b/src/unstract/clone/client.py @@ -1,4 +1,4 @@ -"""Thin Platform API client for the migration subpackage. +"""Thin Platform API client for the clone subpackage. One ``PlatformClient`` instance per ``OrgEndpoint``. Methods are entity- scoped (``list_adapters``, ``create_adapter``, ...) so call sites in phases @@ -16,8 +16,8 @@ import requests -from unstract.migration.context import OrgEndpoint -from unstract.migration.exceptions import PlatformAPIError +from unstract.clone.context import OrgEndpoint +from unstract.clone.exceptions import PlatformAPIError logger = logging.getLogger(__name__) @@ -188,7 +188,7 @@ def update_custom_tool( def list_profiles(self, tool_id: str) -> list[dict[str, Any]]: """List ProfileManager rows for a tool. - Migration reads this on the source only — to discover the + The clone reads this on the source only — to discover the default profile's adapter UUIDs so they can be remapped to target adapter ids for ``import_project``. """ diff --git a/src/unstract/migration/context.py b/src/unstract/clone/context.py similarity index 84% rename from src/unstract/migration/context.py rename to src/unstract/clone/context.py index c3beec2..5a56be4 100644 --- a/src/unstract/migration/context.py +++ b/src/unstract/clone/context.py @@ -1,10 +1,10 @@ -"""Shared state passed between migration phases. +"""Shared state passed between clone phases. Three top-level types: - ``OrgEndpoint`` — base URL + organization_id + Platform API key for one org. -- ``MigrationOptions`` — run flags (dry-run, include/exclude, name-conflict). 
-- ``MigrationContext`` — bundles source/target clients, options, and the +- ``CloneOptions`` — run flags (dry-run, include/exclude, name-conflict). +- ``CloneContext`` — bundles source/target clients, options, and the per-run ``RemapTable``. ``RemapTable`` lives here too because every phase touches it. @@ -16,12 +16,12 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from unstract.migration.client import PlatformClient + from unstract.clone.client import PlatformClient @dataclass(frozen=True) class OrgEndpoint: - """One end of a migration: where to talk to and who to talk as. + """One end of a clone: where to talk to and who to talk as. ``organization_id`` is the slug embedded in the URL path; the bearer Platform API key must belong to this org. ``api_path_prefix`` matches @@ -38,8 +38,8 @@ class OrgEndpoint: @dataclass -class MigrationOptions: - """Per-run flags for ``migrate()``.""" +class CloneOptions: + """Per-run flags for ``clone()``.""" dry_run: bool = False include: tuple[str, ...] | None = None @@ -88,8 +88,8 @@ def snapshot(self) -> dict[str, dict[str, str]]: @dataclass -class MigrationContext: - """Shared state for one ``migrate()`` invocation. +class CloneContext: + """Shared state for one ``clone()`` invocation. Phases hold a reference to this and call ``ctx.source`` / ``ctx.target`` to drive HTTP, ``ctx.remap`` to record UUID mappings. 
@@ -97,5 +97,5 @@ class MigrationContext: source: PlatformClient target: PlatformClient - options: MigrationOptions + options: CloneOptions remap: RemapTable = field(default_factory=RemapTable) diff --git a/src/unstract/migration/exceptions.py b/src/unstract/clone/exceptions.py similarity index 64% rename from src/unstract/migration/exceptions.py rename to src/unstract/clone/exceptions.py index 572d9e5..47bffe2 100644 --- a/src/unstract/migration/exceptions.py +++ b/src/unstract/clone/exceptions.py @@ -1,11 +1,11 @@ -"""Exceptions raised by the migration subpackage.""" +"""Exceptions raised by the clone subpackage.""" -class MigrationError(Exception): - """Base class for all migration errors.""" +class CloneError(Exception): + """Base class for all clone errors.""" -class PlatformAPIError(MigrationError): +class PlatformAPIError(CloneError): """Raised when the Platform API returns a non-2xx response we can't recover from.""" def __init__(self, message: str, status_code: int | None = None, body: str | None = None): @@ -14,9 +14,9 @@ def __init__(self, message: str, status_code: int | None = None, body: str | Non self.body = body -class NameConflictError(MigrationError): +class NameConflictError(CloneError): """Raised when ``on_name_conflict='abort'`` and the target has a like-named entity.""" -class DependencyMissingError(MigrationError): +class DependencyMissingError(CloneError): """Raised when a phase references a source UUID that no prior phase has mapped.""" diff --git a/src/unstract/migration/orchestrator.py b/src/unstract/clone/orchestrator.py similarity index 71% rename from src/unstract/migration/orchestrator.py rename to src/unstract/clone/orchestrator.py index 6b5dc8b..2a21d54 100644 --- a/src/unstract/migration/orchestrator.py +++ b/src/unstract/clone/orchestrator.py @@ -1,8 +1,8 @@ -"""Top-level ``migrate()`` entry point. +"""Top-level ``clone()`` entry point. 
Wires source/target ``PlatformClient`` instances, builds a -``MigrationContext``, runs each phase in strict topological order, and -returns a ``MigrationReport``. +``CloneContext``, runs each phase in strict topological order, and +returns a ``CloneReport``. Phase order is owned here — phases must not call each other. Adding a new entity type means: write a new ``Phase`` subclass and append it to @@ -13,10 +13,10 @@ import logging -from unstract.migration.client import PlatformClient -from unstract.migration.context import MigrationContext, MigrationOptions, OrgEndpoint -from unstract.migration.exceptions import MigrationError -from unstract.migration.phases import ( +from unstract.clone.client import PlatformClient +from unstract.clone.context import CloneContext, CloneOptions, OrgEndpoint +from unstract.clone.exceptions import CloneError +from unstract.clone.phases import ( AdapterPhase, APIDeploymentPhase, ConnectorPhase, @@ -28,8 +28,8 @@ WorkflowEndpointPhase, WorkflowPhase, ) -from unstract.migration.phases.base import Phase -from unstract.migration.report import MigrationReport +from unstract.clone.phases.base import Phase +from unstract.clone.report import Endpoint, CloneReport logger = logging.getLogger(__name__) @@ -53,23 +53,26 @@ ] -def migrate( +def clone( source: OrgEndpoint, target: OrgEndpoint, - options: MigrationOptions | None = None, -) -> MigrationReport: + options: CloneOptions | None = None, +) -> CloneReport: """Migrate configured resources from one org to another. - Returns a ``MigrationReport`` even on partial failure; raises only on + Returns a ``CloneReport`` even on partial failure; raises only on setup errors or ``on_name_conflict='abort'`` collisions. 
""" - opts = options or MigrationOptions() - ctx = MigrationContext( + opts = options or CloneOptions() + ctx = CloneContext( source=PlatformClient(source), target=PlatformClient(target), options=opts, ) - report = MigrationReport() + report = CloneReport( + source=Endpoint(base_url=source.base_url, organization_id=source.organization_id), + target=Endpoint(base_url=target.base_url, organization_id=target.organization_id), + ) for name, phase_cls in PHASES: if not opts.includes(name): @@ -79,7 +82,7 @@ def migrate( logger.info("=== Phase: %s ===", name) try: phase_cls(ctx).run(report) - except MigrationError as e: + except CloneError as e: report.aborted = True report.abort_reason = str(e) logger.error("Phase '%s' aborted: %s", name, e) diff --git a/src/unstract/clone/phases/__init__.py b/src/unstract/clone/phases/__init__.py new file mode 100644 index 0000000..03f0952 --- /dev/null +++ b/src/unstract/clone/phases/__init__.py @@ -0,0 +1,34 @@ +"""Per-entity clone phases. + +Each phase implements ``run(report)``, uses ``ctx.source`` / ``ctx.target`` +to drive HTTP, records ``ctx.remap`` entries for downstream phases. + +Dependency order is owned by ``orchestrator.clone`` — phases must NOT +call each other directly. 
+""" + +from unstract.clone.phases.adapter import AdapterPhase +from unstract.clone.phases.api_deployment import APIDeploymentPhase +from unstract.clone.phases.base import Phase +from unstract.clone.phases.connector import ConnectorPhase +from unstract.clone.phases.custom_tool import CustomToolPhase +from unstract.clone.phases.files import FilesPhase +from unstract.clone.phases.pipeline import PipelinePhase +from unstract.clone.phases.tag import TagPhase +from unstract.clone.phases.tool_instance import ToolInstancePhase +from unstract.clone.phases.workflow import WorkflowPhase +from unstract.clone.phases.workflow_endpoint import WorkflowEndpointPhase + +__all__ = [ + "APIDeploymentPhase", + "AdapterPhase", + "ConnectorPhase", + "CustomToolPhase", + "FilesPhase", + "Phase", + "PipelinePhase", + "TagPhase", + "ToolInstancePhase", + "WorkflowEndpointPhase", + "WorkflowPhase", +] diff --git a/src/unstract/migration/phases/adapter.py b/src/unstract/clone/phases/adapter.py similarity index 88% rename from src/unstract/migration/phases/adapter.py rename to src/unstract/clone/phases/adapter.py index 711ef8c..9fb5668 100644 --- a/src/unstract/migration/phases/adapter.py +++ b/src/unstract/clone/phases/adapter.py @@ -5,7 +5,7 @@ remap table for downstream phases. Frictionless onboarding adapters are excluded — the backend's -service-account queryset already filters them out, so migration never +service-account queryset already filters them out, so clone never sees them. 
""" @@ -14,9 +14,9 @@ import logging from typing import Any -from unstract.migration.exceptions import NameConflictError -from unstract.migration.phases.base import Phase, build_post_payload -from unstract.migration.report import MigrationReport, PhaseResult +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.base import Phase, build_post_payload +from unstract.clone.report import CloneReport, PhaseResult logger = logging.getLogger(__name__) @@ -26,7 +26,7 @@ class AdapterPhase(Phase): name = "adapter" - def run(self, report: MigrationReport) -> PhaseResult: + def run(self, report: CloneReport) -> PhaseResult: result = report.get_phase(self.name) try: self._writable = self.ctx.target.get_post_schema(ADAPTER_PATH) @@ -45,10 +45,10 @@ def run(self, report: MigrationReport) -> PhaseResult: logger.info("Found %d adapter(s) in source org", len(src_summaries)) for summary in src_summaries: - self._migrate_one(summary, result) + self._clone_one(summary, result) return result - def _migrate_one(self, summary: dict[str, Any], result: PhaseResult) -> None: + def _clone_one(self, summary: dict[str, Any], result: PhaseResult) -> None: name = summary["adapter_name"] atype = summary["adapter_type"] src_id = summary["id"] diff --git a/src/unstract/migration/phases/api_deployment.py b/src/unstract/clone/phases/api_deployment.py similarity index 91% rename from src/unstract/migration/phases/api_deployment.py rename to src/unstract/clone/phases/api_deployment.py index e88beff..00c85b7 100644 --- a/src/unstract/migration/phases/api_deployment.py +++ b/src/unstract/clone/phases/api_deployment.py @@ -8,7 +8,7 @@ On create the backend auto-provisions a single active API key and returns it on the response. Extra rotated keys on the source are NOT mirrored (server-generated UUIDs can't be preserved; rotate -post-migration). +post-clone). 
""" from __future__ import annotations @@ -16,10 +16,10 @@ import logging from typing import Any -from unstract.migration.exceptions import NameConflictError -from unstract.migration.phases.base import Phase, build_post_payload -from unstract.migration.report import MigrationReport, PhaseResult -from unstract.migration.walker import remap_uuids +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.base import Phase, build_post_payload +from unstract.clone.report import CloneReport, PhaseResult +from unstract.clone.walker import remap_uuids logger = logging.getLogger(__name__) @@ -29,7 +29,7 @@ class APIDeploymentPhase(Phase): name = "api_deployment" - def run(self, report: MigrationReport) -> PhaseResult: + def run(self, report: CloneReport) -> PhaseResult: result = report.get_phase(self.name) try: self._writable = self.ctx.target.get_post_schema(API_DEPLOYMENT_PATH) @@ -51,10 +51,10 @@ def run(self, report: MigrationReport) -> PhaseResult: logger.info("Found %d source API deployment(s)", len(src_deployments)) for src in src_deployments: - self._migrate_one(src, result) + self._clone_one(src, result) return result - def _migrate_one(self, src: dict[str, Any], result: PhaseResult) -> None: + def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: api_name = src["api_name"] src_id = src["id"] src_wf_id = src.get("workflow") or src.get("workflow_id") diff --git a/src/unstract/migration/phases/base.py b/src/unstract/clone/phases/base.py similarity index 86% rename from src/unstract/migration/phases/base.py rename to src/unstract/clone/phases/base.py index e8dd5a1..6208192 100644 --- a/src/unstract/migration/phases/base.py +++ b/src/unstract/clone/phases/base.py @@ -1,4 +1,4 @@ -"""Base class for migration phases.""" +"""Base class for clone phases.""" from __future__ import annotations @@ -6,8 +6,8 @@ from abc import ABC, abstractmethod from typing import Any -from unstract.migration.context import MigrationContext -from 
unstract.migration.report import MigrationReport, PhaseResult +from unstract.clone.context import CloneContext +from unstract.clone.report import CloneReport, PhaseResult logger = logging.getLogger(__name__) @@ -55,10 +55,10 @@ class Phase(ABC): name: str = "" - def __init__(self, ctx: MigrationContext): + def __init__(self, ctx: CloneContext): self.ctx = ctx @abstractmethod - def run(self, report: MigrationReport) -> PhaseResult: + def run(self, report: CloneReport) -> PhaseResult: """Migrate all entities of this phase's type. Idempotent across runs.""" raise NotImplementedError diff --git a/src/unstract/migration/phases/connector.py b/src/unstract/clone/phases/connector.py similarity index 91% rename from src/unstract/migration/phases/connector.py rename to src/unstract/clone/phases/connector.py index 2e2ea33..816456f 100644 --- a/src/unstract/migration/phases/connector.py +++ b/src/unstract/clone/phases/connector.py @@ -22,9 +22,9 @@ import logging from typing import Any -from unstract.migration.exceptions import NameConflictError -from unstract.migration.phases.base import Phase, build_post_payload -from unstract.migration.report import MigrationReport, PhaseResult +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.base import Phase, build_post_payload +from unstract.clone.report import CloneReport, PhaseResult logger = logging.getLogger(__name__) @@ -34,7 +34,7 @@ class ConnectorPhase(Phase): name = "connector" - def run(self, report: MigrationReport) -> PhaseResult: + def run(self, report: CloneReport) -> PhaseResult: result = report.get_phase(self.name) try: self._writable = self.ctx.target.get_post_schema(CONNECTOR_PATH) @@ -53,10 +53,10 @@ def run(self, report: MigrationReport) -> PhaseResult: logger.info("Found %d connector(s) in source org", len(src_summaries)) for summary in src_summaries: - self._migrate_one(summary, result) + self._clone_one(summary, result) return result - def _migrate_one(self, summary: dict[str, 
Any], result: PhaseResult) -> None: + def _clone_one(self, summary: dict[str, Any], result: PhaseResult) -> None: name = summary["connector_name"] src_id = summary["id"] diff --git a/src/unstract/migration/phases/custom_tool.py b/src/unstract/clone/phases/custom_tool.py similarity index 96% rename from src/unstract/migration/phases/custom_tool.py rename to src/unstract/clone/phases/custom_tool.py index c7bec8c..8cb03ca 100644 --- a/src/unstract/migration/phases/custom_tool.py +++ b/src/unstract/clone/phases/custom_tool.py @@ -30,9 +30,9 @@ import logging from typing import Any -from unstract.migration.exceptions import NameConflictError -from unstract.migration.phases.base import Phase -from unstract.migration.report import MigrationReport, PhaseResult +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.base import Phase +from unstract.clone.report import CloneReport, PhaseResult logger = logging.getLogger(__name__) @@ -59,7 +59,7 @@ def _extract_adapter_name(value: Any) -> str | None: class CustomToolPhase(Phase): name = "custom_tool" - def run(self, report: MigrationReport) -> PhaseResult: + def run(self, report: CloneReport) -> PhaseResult: result = report.get_phase(self.name) try: src_tools = self.ctx.source.list_custom_tools() @@ -82,10 +82,10 @@ def run(self, report: MigrationReport) -> PhaseResult: result.errors.append(f"list target tools: {e}") return for summary in src_tools: - self._migrate_one(summary, target_tools, result) + self._clone_one(summary, target_tools, result) return result - def _migrate_one( + def _clone_one( self, summary: dict[str, Any], target_tools: list[dict[str, Any]], @@ -145,7 +145,7 @@ def _migrate_one( except Exception as e: logger.warning( "registry remap lookup failed for tool '%s' " - "(downstream ToolInstance migration may skip): %s", + "(downstream ToolInstance clone may skip): %s", tool_name, e, ) return @@ -241,7 +241,7 @@ def _resolve_target_adapter_ids( Returns ``None`` if any of the four 
required adapters can't be found on target — caller fails the tool. AdapterPhase preserves names across orgs so this lookup should always hit when the - adapter migration ran cleanly. + adapter clone ran cleanly. """ try: src_profiles = self.ctx.source.list_profiles(src_tool_id) diff --git a/src/unstract/migration/phases/files.py b/src/unstract/clone/phases/files.py similarity index 95% rename from src/unstract/migration/phases/files.py rename to src/unstract/clone/phases/files.py index 868067c..da403f2 100644 --- a/src/unstract/migration/phases/files.py +++ b/src/unstract/clone/phases/files.py @@ -9,7 +9,7 @@ DM rows once each. 2. For each source filename missing on target: download from source, decode per mime, enforce the size cap, upload as multipart to target. -3. Oversize files → ``MigrationReport.oversize_files``; mime types the +3. Oversize files → ``CloneReport.oversize_files``; mime types the backend can't round-trip (Excel placeholder, etc) → ``unsupported_files``; transport errors → ``failed_files``. @@ -20,7 +20,7 @@ Concurrency is 1 per phase by design — the Platform API endpoint holds a cloud worker for the whole upload, and uploads are not chunked on the BE -helper today. See ``docs/internal/files-migration-plan.md`` for the +helper today. See ``docs/internal/files-clone-plan.md`` for the sizing rationale. 
""" @@ -33,9 +33,9 @@ import requests -from unstract.migration.exceptions import PlatformAPIError -from unstract.migration.phases.base import Phase -from unstract.migration.report import MigrationReport, PhaseResult +from unstract.clone.exceptions import PlatformAPIError +from unstract.clone.phases.base import Phase +from unstract.clone.report import CloneReport, PhaseResult logger = logging.getLogger(__name__) @@ -53,7 +53,7 @@ class FilesPhase(Phase): name = "files" - def run(self, report: MigrationReport) -> PhaseResult: + def run(self, report: CloneReport) -> PhaseResult: result = report.get_phase(self.name) tool_remap = self.ctx.remap.snapshot().get("custom_tool", {}) if not tool_remap: @@ -85,19 +85,19 @@ def run(self, report: MigrationReport) -> PhaseResult: self._emit_skip(src_docs, src_tool_id, tgt_tool_id, tool_name, report, result) continue - self._migrate_tool( + self._clone_tool( src_tool_id, tgt_tool_id, tool_name, src_docs, report, result ) return result - def _migrate_tool( + def _clone_tool( self, src_tool_id: str, tgt_tool_id: str, tool_name: str, src_docs: list[dict[str, Any]], - report: MigrationReport, + report: CloneReport, result: PhaseResult, ) -> None: try: @@ -127,11 +127,11 @@ def _migrate_tool( if self.ctx.options.dry_run: result.skipped += 1 logger.info( - "[dry-run] files: would migrate tool=%s file=%s", + "[dry-run] files: would clone tool=%s file=%s", tool_name, file_name, ) continue - self._migrate_one_file( + self._clone_one_file( src_tool_id, tgt_tool_id, tool_name, @@ -146,14 +146,14 @@ def _migrate_tool( src_tool_id, tgt_tool_id, tool_name, src_docs ) - def _migrate_one_file( + def _clone_one_file( self, src_tool_id: str, tgt_tool_id: str, tool_name: str, file_name: str, src_document_id: str, - report: MigrationReport, + report: CloneReport, result: PhaseResult, ) -> None: try: @@ -256,7 +256,7 @@ def _emit_skip( src_tool_id: str, tgt_tool_id: str, tool_name: str, - report: MigrationReport, + report: CloneReport, result: 
PhaseResult, ) -> None: for doc in src_docs: diff --git a/src/unstract/migration/phases/pipeline.py b/src/unstract/clone/phases/pipeline.py similarity index 91% rename from src/unstract/migration/phases/pipeline.py rename to src/unstract/clone/phases/pipeline.py index 8a05487..15b1835 100644 --- a/src/unstract/migration/phases/pipeline.py +++ b/src/unstract/clone/phases/pipeline.py @@ -4,7 +4,7 @@ backend force-sets ``active=True`` and auto-provisions one active API key per pipeline; if the source had additional rotated keys, those are NOT mirrored (their UUIDs are server-generated and can't be preserved, -and operators rotate post-migration anyway). +and operators rotate post-clone anyway). ``DEFAULT`` (legacy) and ``APP`` pipeline types are skipped — DEFAULT is dead code from the v1 era; APP is a Streamlit-style deployment whose @@ -16,10 +16,10 @@ import logging from typing import Any -from unstract.migration.exceptions import NameConflictError -from unstract.migration.phases.base import Phase, build_post_payload -from unstract.migration.report import MigrationReport, PhaseResult -from unstract.migration.walker import remap_uuids +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.base import Phase, build_post_payload +from unstract.clone.report import CloneReport, PhaseResult +from unstract.clone.walker import remap_uuids logger = logging.getLogger(__name__) @@ -30,7 +30,7 @@ class PipelinePhase(Phase): name = "pipeline" - def run(self, report: MigrationReport) -> PhaseResult: + def run(self, report: CloneReport) -> PhaseResult: result = report.get_phase(self.name) try: self._writable = self.ctx.target.get_post_schema(PIPELINE_PATH) @@ -61,10 +61,10 @@ def run(self, report: MigrationReport) -> PhaseResult: logger.info("Found %d source pipeline(s)", len(src_pipelines)) for src in migratable: - self._migrate_one(src, result) + self._clone_one(src, result) return result - def _migrate_one(self, src: dict[str, Any], result: 
PhaseResult) -> None: + def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: name = src["pipeline_name"] src_id = src["id"] src_wf_id = src.get("workflow") or src.get("workflow_id") diff --git a/src/unstract/migration/phases/tag.py b/src/unstract/clone/phases/tag.py similarity index 85% rename from src/unstract/migration/phases/tag.py rename to src/unstract/clone/phases/tag.py index 9381a26..e6ae31a 100644 --- a/src/unstract/migration/phases/tag.py +++ b/src/unstract/clone/phases/tag.py @@ -2,7 +2,7 @@ Tags are flat (``name`` + ``description``) with a per-org uniqueness constraint on ``name``. No metadata, no encryption, no list-vs-detail -divergence — the simplest entity in the migration set. +divergence — the simplest entity in the clone set. List endpoint paginates; ``PlatformClient.list_tags`` already unwraps the envelope. @@ -13,9 +13,9 @@ import logging from typing import Any -from unstract.migration.exceptions import NameConflictError -from unstract.migration.phases.base import Phase, build_post_payload -from unstract.migration.report import MigrationReport, PhaseResult +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.base import Phase, build_post_payload +from unstract.clone.report import CloneReport, PhaseResult logger = logging.getLogger(__name__) @@ -25,7 +25,7 @@ class TagPhase(Phase): name = "tag" - def run(self, report: MigrationReport) -> PhaseResult: + def run(self, report: CloneReport) -> PhaseResult: result = report.get_phase(self.name) try: self._writable = self.ctx.target.get_post_schema(TAG_PATH) @@ -44,10 +44,10 @@ def run(self, report: MigrationReport) -> PhaseResult: logger.info("Found %d tag(s) in source org", len(src_tags)) for src in src_tags: - self._migrate_one(src, result) + self._clone_one(src, result) return result - def _migrate_one(self, src: dict[str, Any], result: PhaseResult) -> None: + def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: name = 
src["name"] src_id = src["id"] diff --git a/src/unstract/migration/phases/tool_instance.py b/src/unstract/clone/phases/tool_instance.py similarity index 85% rename from src/unstract/migration/phases/tool_instance.py rename to src/unstract/clone/phases/tool_instance.py index efe5a26..ac41a9b 100644 --- a/src/unstract/migration/phases/tool_instance.py +++ b/src/unstract/clone/phases/tool_instance.py @@ -22,8 +22,8 @@ import logging from typing import Any -from unstract.migration.phases.base import Phase -from unstract.migration.report import MigrationReport, PhaseResult +from unstract.clone.phases.base import Phase +from unstract.clone.report import CloneReport, PhaseResult logger = logging.getLogger(__name__) @@ -39,6 +39,18 @@ "[NEEDS UPDATE]", ) +# Identity fields that point at backend rows by primary key. They were +# populated server-side at create time on source and must NOT be carried +# across orgs — the target's create_tool_instance has already set the +# correct target values. Leaking source ids here makes the structure +# tool fetch the source registry at runtime (platform-service looks up +# registries by id only, no org scope) and load the wrong adapters. +_SOURCE_IDENTITY_FIELDS: tuple[str, ...] 
= ( + "prompt_registry_id", + "tool_instance_id", + "tenant_id", +) + def _broken_adapter_keys(metadata: dict[str, Any]) -> list[str]: broken: list[str] = [] @@ -50,10 +62,14 @@ def _broken_adapter_keys(metadata: dict[str, Any]) -> list[str]: return broken +def _strip_source_identity(metadata: dict[str, Any]) -> dict[str, Any]: + return {k: v for k, v in metadata.items() if k not in _SOURCE_IDENTITY_FIELDS} + + class ToolInstancePhase(Phase): name = "tool_instance" - def run(self, report: MigrationReport) -> PhaseResult: + def run(self, report: CloneReport) -> PhaseResult: result = report.get_phase(self.name) workflow_remap = self.ctx.remap.snapshot().get("workflow", {}) if not workflow_remap: @@ -61,10 +77,10 @@ def run(self, report: MigrationReport) -> PhaseResult: return result for src_wf_id, tgt_wf_id in workflow_remap.items(): - self._migrate_workflow_tools(src_wf_id, tgt_wf_id, result) + self._clone_workflow_tools(src_wf_id, tgt_wf_id, result) return result - def _migrate_workflow_tools( + def _clone_workflow_tools( self, src_wf_id: str, tgt_wf_id: str, result: PhaseResult ) -> None: try: @@ -154,9 +170,10 @@ def _migrate_workflow_tools( f"stale adapter refs on src tool_instance {src_ti_id}: {broken}" ) else: + patch_metadata = _strip_source_identity(src_metadata) try: self.ctx.target.update_tool_instance_metadata( - tgt_ti["id"], src_metadata + tgt_ti["id"], patch_metadata ) except Exception as e: logger.exception( diff --git a/src/unstract/migration/phases/workflow.py b/src/unstract/clone/phases/workflow.py similarity index 86% rename from src/unstract/migration/phases/workflow.py rename to src/unstract/clone/phases/workflow.py index 49917ee..f8095b9 100644 --- a/src/unstract/migration/phases/workflow.py +++ b/src/unstract/clone/phases/workflow.py @@ -1,6 +1,6 @@ """Migrate workflows from source org to target org. 
-Workflow rows themselves are simple — no required FKs to migration +Workflow rows themselves are simple — no required FKs to clone entities, unique per ``(workflow_name, organization)``. The two non-trivial bits: @@ -18,10 +18,10 @@ import logging from typing import Any -from unstract.migration.exceptions import NameConflictError -from unstract.migration.phases.base import Phase, build_post_payload -from unstract.migration.report import MigrationReport, PhaseResult -from unstract.migration.walker import remap_uuids +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.base import Phase, build_post_payload +from unstract.clone.report import CloneReport, PhaseResult +from unstract.clone.walker import remap_uuids logger = logging.getLogger(__name__) @@ -31,7 +31,7 @@ class WorkflowPhase(Phase): name = "workflow" - def run(self, report: MigrationReport) -> PhaseResult: + def run(self, report: CloneReport) -> PhaseResult: result = report.get_phase(self.name) try: self._writable = self.ctx.target.get_post_schema(WORKFLOW_PATH) @@ -51,10 +51,10 @@ def run(self, report: MigrationReport) -> PhaseResult: logger.info("Found %d workflow(s) in source org", len(src_workflows)) for src in src_workflows: - self._migrate_one(src, result) + self._clone_one(src, result) return result - def _migrate_one(self, src: dict[str, Any], result: PhaseResult) -> None: + def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: name = src["workflow_name"] src_id = src["id"] diff --git a/src/unstract/migration/phases/workflow_endpoint.py b/src/unstract/clone/phases/workflow_endpoint.py similarity index 94% rename from src/unstract/migration/phases/workflow_endpoint.py rename to src/unstract/clone/phases/workflow_endpoint.py index 3e53995..a54f34a 100644 --- a/src/unstract/migration/phases/workflow_endpoint.py +++ b/src/unstract/clone/phases/workflow_endpoint.py @@ -22,9 +22,9 @@ import logging from typing import Any -from 
unstract.migration.phases.base import Phase -from unstract.migration.report import MigrationReport, PhaseResult -from unstract.migration.walker import remap_uuids +from unstract.clone.phases.base import Phase +from unstract.clone.report import CloneReport, PhaseResult +from unstract.clone.walker import remap_uuids logger = logging.getLogger(__name__) @@ -42,7 +42,7 @@ def _extract_connector_id(endpoint: dict[str, Any]) -> str | None: class WorkflowEndpointPhase(Phase): name = "workflow_endpoint" - def run(self, report: MigrationReport) -> PhaseResult: + def run(self, report: CloneReport) -> PhaseResult: result = report.get_phase(self.name) workflow_remap = self.ctx.remap.snapshot().get("workflow", {}) if not workflow_remap: @@ -50,10 +50,10 @@ def run(self, report: MigrationReport) -> PhaseResult: return result for src_wf_id, tgt_wf_id in workflow_remap.items(): - self._migrate_workflow_endpoints(src_wf_id, tgt_wf_id, result) + self._clone_workflow_endpoints(src_wf_id, tgt_wf_id, result) return result - def _migrate_workflow_endpoints( + def _clone_workflow_endpoints( self, src_wf_id: str, tgt_wf_id: str, result: PhaseResult ) -> None: try: diff --git a/src/unstract/migration/report.py b/src/unstract/clone/report.py similarity index 66% rename from src/unstract/migration/report.py rename to src/unstract/clone/report.py index fa06568..9b1b9ce 100644 --- a/src/unstract/migration/report.py +++ b/src/unstract/clone/report.py @@ -1,4 +1,4 @@ -"""Structured report produced by ``migrate()``. +"""Structured report produced by ``clone()``. Tracks per-phase counts (created / adopted / skipped / failed) and a final remap snapshot. 
Renders to a rich-formatted table when ``rich`` is @@ -25,7 +25,17 @@ class PhaseResult: @dataclass -class MigrationReport: +class Endpoint: + """Just enough about an endpoint for the report header — never carries the API key.""" + + base_url: str + organization_id: str + + +@dataclass +class CloneReport: + source: Endpoint | None = None + target: Endpoint | None = None phases: list[PhaseResult] = field(default_factory=list) skipped_phases: list[str] = field(default_factory=list) remap_snapshot: dict[str, dict[str, str]] = field(default_factory=dict) @@ -58,23 +68,63 @@ def render(self) -> str: return self._render_plain() buf = StringIO() - console = Console(file=buf, force_terminal=False, width=100) - table = Table(title="Migration Report") - for col in ("Phase", "Created", "Adopted", "Skipped", "Failed"): - table.add_column(col, justify="right" if col != "Phase" else "left") + # force_terminal so ANSI codes survive the StringIO capture; the + # caller decides whether to strip them when printing to a non-tty. 
+ console = Console(file=buf, force_terminal=True, color_system="truecolor", width=100) + self._render_endpoints(console.print) + table = Table(title="Clone Report", header_style="bold cyan") + table.add_column("Phase", style="bold", justify="left") + for col in ("Created", "Adopted", "Skipped", "Failed"): + table.add_column(col, justify="right") + + totals = {"created": 0, "adopted": 0, "skipped": 0, "failed": 0} for p in self.phases: - table.add_row(p.name, str(p.created), str(p.adopted), str(p.skipped), str(p.failed)) + phase_style = "red" if p.failed else ("yellow" if p.skipped else "green") + table.add_row( + f"[{phase_style}]{p.name}[/{phase_style}]", + self._fmt_count(p.created, "green"), + self._fmt_count(p.adopted, "green"), + self._fmt_count(p.skipped, "yellow"), + self._fmt_count(p.failed, "red"), + ) + for k in totals: + totals[k] += getattr(p, k) + + table.add_section() + table.add_row( + "[bold]TOTAL[/bold]", + self._fmt_count(totals["created"], "green", bold=True), + self._fmt_count(totals["adopted"], "green", bold=True), + self._fmt_count(totals["skipped"], "yellow", bold=True), + self._fmt_count(totals["failed"], "red", bold=True), + ) console.print(table) if self.skipped_phases: console.print(f"[dim]Skipped phases:[/dim] {', '.join(self.skipped_phases)}") self._render_files_sections(console) self._render_remap_summary(console_print=console.print) if self.aborted: - console.print(f"[red]ABORTED:[/red] {self.abort_reason}") + console.print(f"[bold red]ABORTED:[/bold red] {self.abort_reason}") + elif totals["failed"]: + console.print( + f"[bold red]Completed with {totals['failed']} failure(s)[/bold red] — " + "see WARNING/ERROR log lines above for details" + ) + else: + console.print("[bold green]Completed successfully[/bold green]") return buf.getvalue() + @staticmethod + def _fmt_count(value: int, color: str, bold: bool = False) -> str: + """Dim a zero to keep the eye on non-zero cells; colour anything > 0.""" + if value == 0: + return 
"[dim]0[/dim]" + style = f"bold {color}" if bold else color + return f"[{style}]{value}[/{style}]" + def _render_plain(self) -> str: - lines = ["Migration Report", "=" * 60] + lines = ["Clone Report", "=" * 60] + self._render_endpoints(lines.append) header = f"{'Phase':<24}{'Created':>10}{'Adopted':>10}{'Skipped':>10}{'Failed':>10}" lines.append(header) for p in self.phases: @@ -91,6 +141,16 @@ def _render_plain(self) -> str: def as_dict(self) -> dict[str, Any]: return { + "source": ( + {"base_url": self.source.base_url, "organization_id": self.source.organization_id} + if self.source + else None + ), + "target": ( + {"base_url": self.target.base_url, "organization_id": self.target.organization_id} + if self.target + else None + ), "phases": [ { "name": p.name, @@ -113,6 +173,20 @@ def as_dict(self) -> dict[str, Any]: "failed_files": list(self.failed_files), } + def _render_endpoints(self, console_print: Any) -> None: + if not self.source and not self.target: + return + src = self._fmt_endpoint(self.source) + tgt = self._fmt_endpoint(self.target) + console_print(f"Source: {src}") + console_print(f"Target: {tgt}") + + @staticmethod + def _fmt_endpoint(ep: Endpoint | None) -> str: + if ep is None: + return "?" + return f"{ep.organization_id} @ {ep.base_url}" + def _render_remap_summary(self, console_print: Any) -> None: """Summarise the remap snapshot. 
Full map is large and noisy, so we only print per-entity counts here; the full mapping is emitted diff --git a/src/unstract/migration/walker.py b/src/unstract/clone/walker.py similarity index 95% rename from src/unstract/migration/walker.py rename to src/unstract/clone/walker.py index 6e43553..eb9c401 100644 --- a/src/unstract/migration/walker.py +++ b/src/unstract/clone/walker.py @@ -11,7 +11,7 @@ import re from typing import Any -from unstract.migration.context import RemapTable +from unstract.clone.context import RemapTable UUID_RE = re.compile( r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", diff --git a/src/unstract/migration/__main__.py b/src/unstract/migration/__main__.py deleted file mode 100644 index c9d3fd2..0000000 --- a/src/unstract/migration/__main__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Entry point: ``python -m unstract.migration``.""" - -from unstract.migration.cli import main - -if __name__ == "__main__": - main() diff --git a/src/unstract/migration/phases/__init__.py b/src/unstract/migration/phases/__init__.py deleted file mode 100644 index 00e8312..0000000 --- a/src/unstract/migration/phases/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Per-entity migration phases. - -Each phase implements ``run(report)``, uses ``ctx.source`` / ``ctx.target`` -to drive HTTP, records ``ctx.remap`` entries for downstream phases. - -Dependency order is owned by ``orchestrator.migrate`` — phases must NOT -call each other directly. 
-""" - -from unstract.migration.phases.adapter import AdapterPhase -from unstract.migration.phases.api_deployment import APIDeploymentPhase -from unstract.migration.phases.base import Phase -from unstract.migration.phases.connector import ConnectorPhase -from unstract.migration.phases.custom_tool import CustomToolPhase -from unstract.migration.phases.files import FilesPhase -from unstract.migration.phases.pipeline import PipelinePhase -from unstract.migration.phases.tag import TagPhase -from unstract.migration.phases.tool_instance import ToolInstancePhase -from unstract.migration.phases.workflow import WorkflowPhase -from unstract.migration.phases.workflow_endpoint import WorkflowEndpointPhase - -__all__ = [ - "APIDeploymentPhase", - "AdapterPhase", - "ConnectorPhase", - "CustomToolPhase", - "FilesPhase", - "Phase", - "PipelinePhase", - "TagPhase", - "ToolInstancePhase", - "WorkflowEndpointPhase", - "WorkflowPhase", -] diff --git a/tests/migration/__init__.py b/tests/clone/__init__.py similarity index 100% rename from tests/migration/__init__.py rename to tests/clone/__init__.py diff --git a/tests/migration/test_adapter_phase.py b/tests/clone/test_adapter_phase.py similarity index 90% rename from tests/migration/test_adapter_phase.py rename to tests/clone/test_adapter_phase.py index 4530a86..d2b0311 100644 --- a/tests/migration/test_adapter_phase.py +++ b/tests/clone/test_adapter_phase.py @@ -11,14 +11,14 @@ import pytest -from unstract.migration.context import ( - MigrationContext, - MigrationOptions, +from unstract.clone.context import ( + CloneContext, + CloneOptions, RemapTable, ) -from unstract.migration.exceptions import NameConflictError -from unstract.migration.phases.adapter import AdapterPhase -from unstract.migration.report import MigrationReport +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.adapter import AdapterPhase +from unstract.clone.report import CloneReport class FakeClient: @@ -74,10 +74,10 @@ def 
_src_adapter(id_, name, atype="LLM"): def _ctx(source: FakeClient, target: FakeClient, **opt_overrides): - ctx = MigrationContext( + ctx = CloneContext( source=source, target=target, - options=MigrationOptions(**opt_overrides), + options=CloneOptions(**opt_overrides), remap=RemapTable(), ) return ctx @@ -92,7 +92,7 @@ def test_happy_path_creates_all_and_records_remap(): ) tgt = FakeClient() ctx = _ctx(src, tgt) - report = MigrationReport() + report = CloneReport() result = AdapterPhase(ctx).run(report) @@ -120,7 +120,7 @@ def test_idempotency_zero_creates_on_rerun(): ] ) ctx = _ctx(src, tgt, on_name_conflict="adopt") - report = MigrationReport() + report = CloneReport() result = AdapterPhase(ctx).run(report) @@ -134,7 +134,7 @@ def test_dry_run_makes_no_posts(): src = FakeClient([_src_adapter("src-a", "OpenAI Prod")]) tgt = FakeClient() ctx = _ctx(src, tgt, dry_run=True) - report = MigrationReport() + report = CloneReport() result = AdapterPhase(ctx).run(report) @@ -157,7 +157,7 @@ def test_abort_on_name_conflict_raises(): ] ) ctx = _ctx(src, tgt, on_name_conflict="abort") - report = MigrationReport() + report = CloneReport() with pytest.raises(NameConflictError): AdapterPhase(ctx).run(report) diff --git a/tests/migration/test_api_deployment_phase.py b/tests/clone/test_api_deployment_phase.py similarity index 86% rename from tests/migration/test_api_deployment_phase.py rename to tests/clone/test_api_deployment_phase.py index a80900c..df84c91 100644 --- a/tests/migration/test_api_deployment_phase.py +++ b/tests/clone/test_api_deployment_phase.py @@ -15,14 +15,14 @@ import pytest -from unstract.migration.context import ( - MigrationContext, - MigrationOptions, +from unstract.clone.context import ( + CloneContext, + CloneOptions, RemapTable, ) -from unstract.migration.exceptions import NameConflictError -from unstract.migration.phases.api_deployment import APIDeploymentPhase -from unstract.migration.report import MigrationReport +from unstract.clone.exceptions import 
NameConflictError +from unstract.clone.phases.api_deployment import APIDeploymentPhase +from unstract.clone.report import CloneReport API_DEPLOYMENT_POST_SCHEMA = frozenset( @@ -84,10 +84,10 @@ def _src_deployment( def _ctx(source, target, *, remap=None, **opt_overrides): - return MigrationContext( + return CloneContext( source=source, target=target, - options=MigrationOptions(**opt_overrides), + options=CloneOptions(**opt_overrides), remap=remap or RemapTable(), ) @@ -99,7 +99,7 @@ def test_happy_path_creates_deployment_with_remapped_workflow(): remap.record("workflow", "wf-src-1", "wf-tgt-1") ctx = _ctx(src, tgt, remap=remap) - result = APIDeploymentPhase(ctx).run(MigrationReport()) + result = APIDeploymentPhase(ctx).run(CloneReport()) assert result.created == 1 assert result.failed == 0 @@ -118,7 +118,7 @@ def test_adopts_existing_deployment_by_api_name(): remap.record("workflow", "wf-src-1", "wf-tgt-1") ctx = _ctx(src, tgt, remap=remap) - result = APIDeploymentPhase(ctx).run(MigrationReport()) + result = APIDeploymentPhase(ctx).run(CloneReport()) assert result.adopted == 1 assert result.created == 0 @@ -131,7 +131,7 @@ def test_skipped_when_workflow_remap_missing(): tgt = FakeClient() ctx = _ctx(src, tgt) # No workflow remap. 
- result = APIDeploymentPhase(ctx).run(MigrationReport()) + result = APIDeploymentPhase(ctx).run(CloneReport()) assert result.skipped == 1 assert tgt.posts == [] @@ -144,7 +144,7 @@ def test_dry_run_makes_no_writes(): remap.record("workflow", "wf-src-1", "wf-tgt-1") ctx = _ctx(src, tgt, remap=remap, dry_run=True) - result = APIDeploymentPhase(ctx).run(MigrationReport()) + result = APIDeploymentPhase(ctx).run(CloneReport()) assert result.skipped == 1 assert tgt.posts == [] @@ -158,7 +158,7 @@ def test_abort_on_name_conflict_raises(): ctx = _ctx(src, tgt, remap=remap, on_name_conflict="abort") with pytest.raises(NameConflictError): - APIDeploymentPhase(ctx).run(MigrationReport()) + APIDeploymentPhase(ctx).run(CloneReport()) def test_extra_source_keys_log_warning_not_failure(caplog): @@ -173,9 +173,9 @@ def test_extra_source_keys_log_warning_not_failure(caplog): ctx = _ctx(src, tgt, remap=remap) with caplog.at_level( - logging.WARNING, logger="unstract.migration.phases.api_deployment" + logging.WARNING, logger="unstract.clone.phases.api_deployment" ): - result = APIDeploymentPhase(ctx).run(MigrationReport()) + result = APIDeploymentPhase(ctx).run(CloneReport()) assert result.created == 1 assert result.failed == 0 diff --git a/tests/migration/test_base_helpers.py b/tests/clone/test_base_helpers.py similarity index 92% rename from tests/migration/test_base_helpers.py rename to tests/clone/test_base_helpers.py index ada8b87..727affa 100644 --- a/tests/migration/test_base_helpers.py +++ b/tests/clone/test_base_helpers.py @@ -1,8 +1,8 @@ -"""Tests for ``unstract.migration.phases.base`` helpers.""" +"""Tests for ``unstract.clone.phases.base`` helpers.""" from __future__ import annotations -from unstract.migration.phases.base import SERVER_MANAGED, build_post_payload +from unstract.clone.phases.base import SERVER_MANAGED, build_post_payload def test_preserves_false_and_zero_values(): diff --git a/tests/migration/test_connector_phase.py b/tests/clone/test_connector_phase.py 
similarity index 90% rename from tests/migration/test_connector_phase.py rename to tests/clone/test_connector_phase.py index a72a098..4a3c413 100644 --- a/tests/migration/test_connector_phase.py +++ b/tests/clone/test_connector_phase.py @@ -9,14 +9,14 @@ import pytest -from unstract.migration.context import ( - MigrationContext, - MigrationOptions, +from unstract.clone.context import ( + CloneContext, + CloneOptions, RemapTable, ) -from unstract.migration.exceptions import NameConflictError -from unstract.migration.phases.connector import ConnectorPhase -from unstract.migration.report import MigrationReport +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.connector import ConnectorPhase +from unstract.clone.report import CloneReport class FakeClient: @@ -76,10 +76,10 @@ def _src(id_, name, catalog_id="postgres|abc", ctype="INPUT"): def _ctx(source, target, **opt_overrides): - return MigrationContext( + return CloneContext( source=source, target=target, - options=MigrationOptions(**opt_overrides), + options=CloneOptions(**opt_overrides), remap=RemapTable(), ) @@ -88,7 +88,7 @@ def test_happy_path_creates_all_and_records_remap(): src = FakeClient([_src("src-a", "Prod PG"), _src("src-b", "Stg S3", "s3|xyz")]) tgt = FakeClient() ctx = _ctx(src, tgt) - report = MigrationReport() + report = CloneReport() result = ConnectorPhase(ctx).run(report) @@ -108,7 +108,7 @@ def test_redacted_metadata_connector_skipped(): src = FakeClient([redacted]) tgt = FakeClient() ctx = _ctx(src, tgt) - report = MigrationReport() + report = CloneReport() result = ConnectorPhase(ctx).run(report) @@ -132,7 +132,7 @@ def test_idempotency_zero_creates_on_rerun(): ] ) ctx = _ctx(src, tgt, on_name_conflict="adopt") - report = MigrationReport() + report = CloneReport() result = ConnectorPhase(ctx).run(report) @@ -146,7 +146,7 @@ def test_dry_run_makes_no_posts(): src = FakeClient([_src("src-a", "Prod PG")]) tgt = FakeClient() ctx = _ctx(src, tgt, dry_run=True) - 
report = MigrationReport() + report = CloneReport() result = ConnectorPhase(ctx).run(report) @@ -169,7 +169,7 @@ def test_abort_on_name_conflict_raises(): ] ) ctx = _ctx(src, tgt, on_name_conflict="abort") - report = MigrationReport() + report = CloneReport() with pytest.raises(NameConflictError): ConnectorPhase(ctx).run(report) diff --git a/tests/migration/test_custom_tool_phase.py b/tests/clone/test_custom_tool_phase.py similarity index 93% rename from tests/migration/test_custom_tool_phase.py rename to tests/clone/test_custom_tool_phase.py index ec0e8cc..ad1dfd2 100644 --- a/tests/migration/test_custom_tool_phase.py +++ b/tests/clone/test_custom_tool_phase.py @@ -16,14 +16,14 @@ import pytest -from unstract.migration.context import ( - MigrationContext, - MigrationOptions, +from unstract.clone.context import ( + CloneContext, + CloneOptions, RemapTable, ) -from unstract.migration.exceptions import NameConflictError -from unstract.migration.phases.custom_tool import CustomToolPhase -from unstract.migration.report import MigrationReport +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.custom_tool import CustomToolPhase +from unstract.clone.report import CloneReport ADAPTER_NAMES = { @@ -123,11 +123,11 @@ def export_custom_tool(self, tool_id: str, *, force: bool = True) -> None: ) -def _ctx(source, target, *, remap=None, **opt_overrides) -> MigrationContext: - return MigrationContext( +def _ctx(source, target, *, remap=None, **opt_overrides) -> CloneContext: + return CloneContext( source=source, target=target, - options=MigrationOptions(**opt_overrides), + options=CloneOptions(**opt_overrides), remap=remap or RemapTable(), ) @@ -204,7 +204,7 @@ def test_fresh_imports_with_name_resolved_adapter_ids_and_records_registry(): _seed_target_adapters(tgt) ctx = _ctx(src, tgt) - result = CustomToolPhase(ctx).run(MigrationReport()) + result = CustomToolPhase(ctx).run(CloneReport()) assert result.created == 1 assert result.failed == 0 @@ 
-237,7 +237,7 @@ def test_nested_adapter_dict_also_resolves(): _seed_target_adapters(tgt) ctx = _ctx(src, tgt) - CustomToolPhase(ctx).run(MigrationReport()) + CustomToolPhase(ctx).run(CloneReport()) _, adapter_ids = tgt.import_calls[0] assert adapter_ids["llm_adapter_id"] == TGT_ADAPTER_IDS["gpt4"] @@ -253,7 +253,7 @@ def test_adopt_path_calls_sync_prompts_and_skips_import(): _seed_target_adapters(tgt) ctx = _ctx(src, tgt) - result = CustomToolPhase(ctx).run(MigrationReport()) + result = CustomToolPhase(ctx).run(CloneReport()) assert result.adopted == 1 assert result.created == 0 @@ -280,7 +280,7 @@ def test_abort_on_name_conflict_raises(): ctx = _ctx(src, tgt, on_name_conflict="abort") with pytest.raises(NameConflictError): - CustomToolPhase(ctx).run(MigrationReport()) + CustomToolPhase(ctx).run(CloneReport()) assert tgt.sync_calls == [] assert tgt.import_calls == [] @@ -293,7 +293,7 @@ def test_dry_run_makes_no_writes(): _seed_target_adapters(tgt) ctx = _ctx(src, tgt, dry_run=True) - result = CustomToolPhase(ctx).run(MigrationReport()) + result = CustomToolPhase(ctx).run(CloneReport()) assert result.skipped == 1 assert tgt.import_calls == [] @@ -310,7 +310,7 @@ def test_missing_target_adapter_fails_tool_cleanly(): tgt.adapters_by_name[name] = {"id": TGT_ADAPTER_IDS[name], "adapter_name": name} ctx = _ctx(src, tgt) - result = CustomToolPhase(ctx).run(MigrationReport()) + result = CustomToolPhase(ctx).run(CloneReport()) assert result.failed == 1 assert tgt.import_calls == [] diff --git a/tests/migration/test_files_phase.py b/tests/clone/test_files_phase.py similarity index 94% rename from tests/migration/test_files_phase.py rename to tests/clone/test_files_phase.py index 43744ae..1392ca7 100644 --- a/tests/migration/test_files_phase.py +++ b/tests/clone/test_files_phase.py @@ -19,15 +19,15 @@ import pytest -from unstract.migration.context import ( - MigrationContext, - MigrationOptions, +from unstract.clone.context import ( + CloneContext, + CloneOptions, 
OrgEndpoint, RemapTable, ) -from unstract.migration.exceptions import PlatformAPIError -from unstract.migration.phases.files import FilesPhase -from unstract.migration.report import MigrationReport +from unstract.clone.exceptions import PlatformAPIError +from unstract.clone.phases.files import FilesPhase +from unstract.clone.report import CloneReport SRC_ENDPOINT = OrgEndpoint( @@ -125,12 +125,12 @@ def update_custom_tool(self, tool_id: str, body: dict) -> dict: def _ctx(src: FakeClient, tgt: FakeClient, *, remap: RemapTable | None = None, - **opts) -> MigrationContext: + **opts) -> CloneContext: remap = remap or RemapTable() - return MigrationContext( + return CloneContext( source=src, target=tgt, - options=MigrationOptions(**opts), + options=CloneOptions(**opts), remap=remap, ) @@ -160,7 +160,7 @@ def test_happy_path_uploads_pdf_and_text(): remap = RemapTable() remap.record("custom_tool", "src-1", "tgt-1") ctx = _ctx(src, tgt, remap=remap) - report = MigrationReport() + report = CloneReport() result = FilesPhase(ctx).run(report) @@ -190,7 +190,7 @@ def test_target_filename_present_is_skipped_no_download(): remap = RemapTable() remap.record("custom_tool", "src-1", "tgt-1") ctx = _ctx(src, tgt, remap=remap) - report = MigrationReport() + report = CloneReport() result = FilesPhase(ctx).run(report) @@ -214,7 +214,7 @@ def test_oversize_file_is_recorded_and_siblings_continue(): remap = RemapTable() remap.record("custom_tool", "src-1", "tgt-1") ctx = _ctx(src, tgt, remap=remap, max_file_size=10) - report = MigrationReport() + report = CloneReport() result = FilesPhase(ctx).run(report) @@ -242,7 +242,7 @@ def test_unsupported_mime_is_recorded_not_uploaded(): remap = RemapTable() remap.record("custom_tool", "src-1", "tgt-1") ctx = _ctx(src, tgt, remap=remap) - report = MigrationReport() + report = CloneReport() result = FilesPhase(ctx).run(report) @@ -263,7 +263,7 @@ def test_skip_strategy_emits_skipped_files_no_traffic(): remap = RemapTable() remap.record("custom_tool", 
"src-1", "tgt-1") ctx = _ctx(src, tgt, remap=remap, file_strategy="skip") - report = MigrationReport() + report = CloneReport() result = FilesPhase(ctx).run(report) @@ -286,7 +286,7 @@ def test_dry_run_makes_no_writes_even_for_missing_files(): remap = RemapTable() remap.record("custom_tool", "src-1", "tgt-1") ctx = _ctx(src, tgt, remap=remap, dry_run=True) - report = MigrationReport() + report = CloneReport() result = FilesPhase(ctx).run(report) @@ -310,8 +310,8 @@ def test_transient_503_is_retried_then_succeeds(monkeypatch): remap.record("custom_tool", "src-1", "tgt-1") ctx = _ctx(src, tgt, remap=remap) # Strip the backoff sleep so the test stays fast. - monkeypatch.setattr("unstract.migration.phases.files.time.sleep", lambda *_: None) - report = MigrationReport() + monkeypatch.setattr("unstract.clone.phases.files.time.sleep", lambda *_: None) + report = CloneReport() result = FilesPhase(ctx).run(report) @@ -323,7 +323,7 @@ def test_no_custom_tool_remap_is_noop(): src = FakeClient(endpoint=SRC_ENDPOINT) tgt = FakeClient(endpoint=TGT_ENDPOINT) ctx = _ctx(src, tgt) - report = MigrationReport() + report = CloneReport() result = FilesPhase(ctx).run(report) @@ -350,7 +350,7 @@ def test_source_list_failure_isolates_to_that_tool(): remap.record("custom_tool", "src-1", "tgt-1") remap.record("custom_tool", "src-2", "tgt-2") ctx = _ctx(src, tgt, remap=remap) - report = MigrationReport() + report = CloneReport() result = FilesPhase(ctx).run(report) @@ -372,7 +372,7 @@ def test_upload_failure_records_failed_files_entry(): remap = RemapTable() remap.record("custom_tool", "src-1", "tgt-1") ctx = _ctx(src, tgt, remap=remap) - report = MigrationReport() + report = CloneReport() result = FilesPhase(ctx).run(report) @@ -401,7 +401,7 @@ def test_text_mimes_round_trip_as_utf8(mime, raw): remap = RemapTable() remap.record("custom_tool", "src-1", "tgt-1") ctx = _ctx(src, tgt, remap=remap) - report = MigrationReport() + report = CloneReport() result = FilesPhase(ctx).run(report) @@ 
-429,7 +429,7 @@ def test_default_doc_mirrors_source_selection_by_filename(): remap.record("custom_tool", "src-1", "tgt-1") ctx = _ctx(src, tgt, remap=remap) - FilesPhase(ctx).run(MigrationReport()) + FilesPhase(ctx).run(CloneReport()) # Target's CustomTool.output now points at b.pdf's new target doc id. tgt_tool = next(t for t in tgt._tools if t["tool_id"] == "tgt-1") @@ -453,7 +453,7 @@ def test_default_doc_falls_back_to_first_when_source_has_none(): remap.record("custom_tool", "src-1", "tgt-1") ctx = _ctx(src, tgt, remap=remap) - FilesPhase(ctx).run(MigrationReport()) + FilesPhase(ctx).run(CloneReport()) tgt_tool = next(t for t in tgt._tools if t["tool_id"] == "tgt-1") a_upload = next(d for d in tgt._documents["tgt-1"] if d["document_name"] == "a.pdf") @@ -478,7 +478,7 @@ def test_default_doc_preserves_existing_target_choice(): remap.record("custom_tool", "src-1", "tgt-1") ctx = _ctx(src, tgt, remap=remap) - FilesPhase(ctx).run(MigrationReport()) + FilesPhase(ctx).run(CloneReport()) tgt_tool = next(t for t in tgt._tools if t["tool_id"] == "tgt-1") assert tgt_tool["output"] == "operator-pick" diff --git a/tests/migration/test_pipeline_phase.py b/tests/clone/test_pipeline_phase.py similarity index 87% rename from tests/migration/test_pipeline_phase.py rename to tests/clone/test_pipeline_phase.py index a064628..982a960 100644 --- a/tests/migration/test_pipeline_phase.py +++ b/tests/clone/test_pipeline_phase.py @@ -2,7 +2,7 @@ Coverage: - happy path: source ETL/TASK pipelines created with workflow FK remapped. -- DEFAULT and APP types are skipped (out of migration scope). +- DEFAULT and APP types are skipped (out of clone scope). - adopt path on name conflict. - skipped when workflow remap missing. - dry-run is a no-op. 
@@ -16,14 +16,14 @@ import pytest -from unstract.migration.context import ( - MigrationContext, - MigrationOptions, +from unstract.clone.context import ( + CloneContext, + CloneOptions, RemapTable, ) -from unstract.migration.exceptions import NameConflictError -from unstract.migration.phases.pipeline import PipelinePhase -from unstract.migration.report import MigrationReport +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.pipeline import PipelinePhase +from unstract.clone.report import CloneReport PIPELINE_POST_SCHEMA = frozenset( @@ -102,10 +102,10 @@ def _src_pipeline( def _ctx(source, target, *, remap=None, **opt_overrides): - return MigrationContext( + return CloneContext( source=source, target=target, - options=MigrationOptions(**opt_overrides), + options=CloneOptions(**opt_overrides), remap=remap or RemapTable(), ) @@ -119,7 +119,7 @@ def test_happy_path_creates_pipeline_with_remapped_workflow(): remap.record("workflow", "wf-src-1", "wf-tgt-1") ctx = _ctx(src, tgt, remap=remap) - result = PipelinePhase(ctx).run(MigrationReport()) + result = PipelinePhase(ctx).run(CloneReport()) assert result.created == 1 assert result.failed == 0 @@ -142,7 +142,7 @@ def test_default_and_app_pipeline_types_are_skipped(): remap.record("workflow", "wf-src-1", "wf-tgt-1") ctx = _ctx(src, tgt, remap=remap) - result = PipelinePhase(ctx).run(MigrationReport()) + result = PipelinePhase(ctx).run(CloneReport()) assert result.created == 1 assert len(tgt.posts) == 1 @@ -158,7 +158,7 @@ def test_adopts_existing_pipeline_by_name(): remap.record("workflow", "wf-src-1", "wf-tgt-1") ctx = _ctx(src, tgt, remap=remap) - result = PipelinePhase(ctx).run(MigrationReport()) + result = PipelinePhase(ctx).run(CloneReport()) assert result.adopted == 1 assert result.created == 0 @@ -171,7 +171,7 @@ def test_skipped_when_workflow_remap_missing(): tgt = FakeClient() ctx = _ctx(src, tgt) # No workflow remap. 
- result = PipelinePhase(ctx).run(MigrationReport()) + result = PipelinePhase(ctx).run(CloneReport()) assert result.skipped == 1 assert result.failed == 0 @@ -185,7 +185,7 @@ def test_dry_run_makes_no_writes(): remap.record("workflow", "wf-src-1", "wf-tgt-1") ctx = _ctx(src, tgt, remap=remap, dry_run=True) - result = PipelinePhase(ctx).run(MigrationReport()) + result = PipelinePhase(ctx).run(CloneReport()) assert result.skipped == 1 assert tgt.posts == [] @@ -199,7 +199,7 @@ def test_abort_on_name_conflict_raises(): ctx = _ctx(src, tgt, remap=remap, on_name_conflict="abort") with pytest.raises(NameConflictError): - PipelinePhase(ctx).run(MigrationReport()) + PipelinePhase(ctx).run(CloneReport()) def test_extra_source_keys_log_warning_not_failure(caplog): @@ -214,8 +214,8 @@ def test_extra_source_keys_log_warning_not_failure(caplog): remap.record("workflow", "wf-src-1", "wf-tgt-1") ctx = _ctx(src, tgt, remap=remap) - with caplog.at_level(logging.WARNING, logger="unstract.migration.phases.pipeline"): - result = PipelinePhase(ctx).run(MigrationReport()) + with caplog.at_level(logging.WARNING, logger="unstract.clone.phases.pipeline"): + result = PipelinePhase(ctx).run(CloneReport()) assert result.created == 1 assert result.failed == 0 diff --git a/tests/migration/test_remap_table.py b/tests/clone/test_remap_table.py similarity index 95% rename from tests/migration/test_remap_table.py rename to tests/clone/test_remap_table.py index 3045326..6a13754 100644 --- a/tests/migration/test_remap_table.py +++ b/tests/clone/test_remap_table.py @@ -1,6 +1,6 @@ """Tests for ``RemapTable``.""" -from unstract.migration.context import RemapTable +from unstract.clone.context import RemapTable def test_record_and_resolve_per_entity(): diff --git a/tests/migration/test_tag_phase.py b/tests/clone/test_tag_phase.py similarity index 85% rename from tests/migration/test_tag_phase.py rename to tests/clone/test_tag_phase.py index c9fc3e9..f6086a9 100644 --- a/tests/migration/test_tag_phase.py 
+++ b/tests/clone/test_tag_phase.py @@ -8,14 +8,14 @@ import pytest -from unstract.migration.context import ( - MigrationContext, - MigrationOptions, +from unstract.clone.context import ( + CloneContext, + CloneOptions, RemapTable, ) -from unstract.migration.exceptions import NameConflictError -from unstract.migration.phases.tag import TagPhase -from unstract.migration.report import MigrationReport +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.tag import TagPhase +from unstract.clone.report import CloneReport class FakeClient: @@ -49,10 +49,10 @@ def _src(id_, name): def _ctx(source, target, **opt_overrides): - return MigrationContext( + return CloneContext( source=source, target=target, - options=MigrationOptions(**opt_overrides), + options=CloneOptions(**opt_overrides), remap=RemapTable(), ) @@ -61,7 +61,7 @@ def test_happy_path_creates_all_and_records_remap(): src = FakeClient([_src("src-a", "billing"), _src("src-b", "finance")]) tgt = FakeClient() ctx = _ctx(src, tgt) - report = MigrationReport() + report = CloneReport() result = TagPhase(ctx).run(report) @@ -76,7 +76,7 @@ def test_idempotency_zero_creates_on_rerun(): src = FakeClient([_src("src-a", "billing")]) tgt = FakeClient([{"id": "preexisting", "name": "billing", "description": "x"}]) ctx = _ctx(src, tgt, on_name_conflict="adopt") - report = MigrationReport() + report = CloneReport() result = TagPhase(ctx).run(report) @@ -90,7 +90,7 @@ def test_dry_run_makes_no_posts(): src = FakeClient([_src("src-a", "billing")]) tgt = FakeClient() ctx = _ctx(src, tgt, dry_run=True) - report = MigrationReport() + report = CloneReport() result = TagPhase(ctx).run(report) @@ -103,7 +103,7 @@ def test_abort_on_name_conflict_raises(): src = FakeClient([_src("src-a", "billing")]) tgt = FakeClient([{"id": "preexisting", "name": "billing", "description": "x"}]) ctx = _ctx(src, tgt, on_name_conflict="abort") - report = MigrationReport() + report = CloneReport() with 
pytest.raises(NameConflictError): TagPhase(ctx).run(report) diff --git a/tests/migration/test_tool_instance_phase.py b/tests/clone/test_tool_instance_phase.py similarity index 79% rename from tests/migration/test_tool_instance_phase.py rename to tests/clone/test_tool_instance_phase.py index c43c311..eb47405 100644 --- a/tests/migration/test_tool_instance_phase.py +++ b/tests/clone/test_tool_instance_phase.py @@ -1,7 +1,7 @@ """Tests for ``ToolInstancePhase``. ToolInstance is unique among phases: -- The source list of "things to migrate" comes from the workflow remap +- The source list of "things to clone" comes from the workflow remap table, not a top-level entity list. - Create is a two-step dance (POST bare, PATCH metadata) because the backend rebuilds metadata from defaults on POST. @@ -9,13 +9,13 @@ from __future__ import annotations -from unstract.migration.context import ( - MigrationContext, - MigrationOptions, +from unstract.clone.context import ( + CloneContext, + CloneOptions, RemapTable, ) -from unstract.migration.phases.tool_instance import ToolInstancePhase -from unstract.migration.report import MigrationReport +from unstract.clone.phases.tool_instance import ToolInstancePhase +from unstract.clone.report import CloneReport class FakeClient: @@ -56,10 +56,10 @@ def update_tool_instance_metadata( def _ctx(source, target, *, remap=None, **opt_overrides): - return MigrationContext( + return CloneContext( source=source, target=target, - options=MigrationOptions(**opt_overrides), + options=CloneOptions(**opt_overrides), remap=remap or RemapTable(), ) @@ -92,13 +92,21 @@ def test_happy_path_creates_instance_then_patches_metadata(): src.instances[SRC_WF] = [ _src_ti( "src-ti-1", SRC_WF, SRC_REG, - {"llm": "My OpenAI", "embedding": "MyEmb", "tenant_id": "src-org"}, + { + "llm": "My OpenAI", + "embedding": "MyEmb", + # Identity fields that the backend populated server-side + # at source create time — must NOT cross the org boundary. 
+ "tenant_id": "src-org", + "prompt_registry_id": "src-registry-uuid", + "tool_instance_id": "src-ti-1-pk", + }, ) ] tgt = FakeClient() ctx = _ctx(src, tgt, remap=_seed_remap()) - result = ToolInstancePhase(ctx).run(MigrationReport()) + result = ToolInstancePhase(ctx).run(CloneReport()) assert result.created == 1 assert result.failed == 0 @@ -106,13 +114,12 @@ def test_happy_path_creates_instance_then_patches_metadata(): posted = tgt.create_calls[0] assert posted["workflow_id"] == TGT_WF assert posted["tool_id"] == TGT_REG - # PATCH carries the source metadata verbatim (backend handles name→UUID). + # PATCH carries the source settings but never the source-internal + # identity fields — the target row already has its own. assert len(tgt.patch_calls) == 1 patched_id, patched_metadata = tgt.patch_calls[0] assert patched_id == posted["id"] - assert patched_metadata == { - "llm": "My OpenAI", "embedding": "MyEmb", "tenant_id": "src-org", - } + assert patched_metadata == {"llm": "My OpenAI", "embedding": "MyEmb"} assert ctx.remap.resolve("tool_instance", "src-ti-1") == posted["id"] @@ -125,7 +132,7 @@ def test_skip_when_registry_remap_missing(): # No prompt_studio_registry remap entry → SDK must skip. 
ctx = _ctx(src, tgt, remap=remap) - result = ToolInstancePhase(ctx).run(MigrationReport()) + result = ToolInstancePhase(ctx).run(CloneReport()) assert result.skipped == 1 assert result.created == 0 @@ -142,7 +149,7 @@ def test_adopt_existing_target_instance_and_repatch_metadata(): ] ctx = _ctx(src, tgt, remap=_seed_remap()) - result = ToolInstancePhase(ctx).run(MigrationReport()) + result = ToolInstancePhase(ctx).run(CloneReport()) assert result.adopted == 1 assert result.created == 0 @@ -157,7 +164,7 @@ def test_no_op_when_no_workflows_in_remap(): tgt = FakeClient() ctx = _ctx(src, tgt, remap=RemapTable()) - result = ToolInstancePhase(ctx).run(MigrationReport()) + result = ToolInstancePhase(ctx).run(CloneReport()) assert result.created == 0 assert result.skipped == 0 @@ -170,7 +177,7 @@ def test_dry_run_does_not_create_or_patch(): tgt = FakeClient() ctx = _ctx(src, tgt, remap=_seed_remap(), dry_run=True) - result = ToolInstancePhase(ctx).run(MigrationReport()) + result = ToolInstancePhase(ctx).run(CloneReport()) assert result.skipped == 1 assert tgt.create_calls == [] diff --git a/tests/migration/test_walker.py b/tests/clone/test_walker.py similarity index 93% rename from tests/migration/test_walker.py rename to tests/clone/test_walker.py index 44107bc..5ae0301 100644 --- a/tests/migration/test_walker.py +++ b/tests/clone/test_walker.py @@ -1,7 +1,7 @@ """Tests for ``remap_uuids``.""" -from unstract.migration.context import RemapTable -from unstract.migration.walker import remap_uuids +from unstract.clone.context import RemapTable +from unstract.clone.walker import remap_uuids SRC_A = "11111111-1111-1111-1111-111111111111" TGT_A = "22222222-2222-2222-2222-222222222222" diff --git a/tests/migration/test_workflow_endpoint_phase.py b/tests/clone/test_workflow_endpoint_phase.py similarity index 90% rename from tests/migration/test_workflow_endpoint_phase.py rename to tests/clone/test_workflow_endpoint_phase.py index 161ce92..1df2837 100644 --- 
a/tests/migration/test_workflow_endpoint_phase.py +++ b/tests/clone/test_workflow_endpoint_phase.py @@ -10,13 +10,13 @@ from __future__ import annotations -from unstract.migration.context import ( - MigrationContext, - MigrationOptions, +from unstract.clone.context import ( + CloneContext, + CloneOptions, RemapTable, ) -from unstract.migration.phases.workflow_endpoint import WorkflowEndpointPhase -from unstract.migration.report import MigrationReport +from unstract.clone.phases.workflow_endpoint import WorkflowEndpointPhase +from unstract.clone.report import CloneReport class FakeClient: @@ -44,10 +44,10 @@ def update_workflow_endpoint( def _ctx(source, target, *, remap=None, **opt_overrides): - return MigrationContext( + return CloneContext( source=source, target=target, - options=MigrationOptions(**opt_overrides), + options=CloneOptions(**opt_overrides), remap=remap or RemapTable(), ) @@ -110,7 +110,7 @@ def test_pairs_endpoints_by_type_and_remaps_connector(): ] ctx = _ctx(src, tgt, remap=_seed_remap()) - result = WorkflowEndpointPhase(ctx).run(MigrationReport()) + result = WorkflowEndpointPhase(ctx).run(CloneReport()) assert result.created == 2 assert result.failed == 0 @@ -146,7 +146,7 @@ def test_endpoint_without_source_connector_patches_with_null(): tgt.endpoints[TGT_WF] = [_tgt_endpoint("tgt-ep-source", "SOURCE")] ctx = _ctx(src, tgt, remap=_seed_remap()) - result = WorkflowEndpointPhase(ctx).run(MigrationReport()) + result = WorkflowEndpointPhase(ctx).run(CloneReport()) assert result.created == 1 assert len(tgt.patch_calls) == 1 @@ -173,7 +173,7 @@ def test_unknown_connector_uuid_skips_endpoint_and_flags_error(): tgt.endpoints[TGT_WF] = [_tgt_endpoint("tgt-ep-source", "SOURCE")] ctx = _ctx(src, tgt, remap=_seed_remap()) - result = WorkflowEndpointPhase(ctx).run(MigrationReport()) + result = WorkflowEndpointPhase(ctx).run(CloneReport()) assert result.created == 0 assert result.skipped == 1 @@ -190,7 +190,7 @@ def test_missing_target_endpoint_fails_loudly(): 
tgt.endpoints[TGT_WF] = [] # No endpoints — anomaly. ctx = _ctx(src, tgt, remap=_seed_remap()) - result = WorkflowEndpointPhase(ctx).run(MigrationReport()) + result = WorkflowEndpointPhase(ctx).run(CloneReport()) assert result.failed == 1 assert tgt.patch_calls == [] @@ -205,7 +205,7 @@ def test_dry_run_makes_no_patches(): tgt.endpoints[TGT_WF] = [_tgt_endpoint("tgt-ep-source", "SOURCE")] ctx = _ctx(src, tgt, remap=_seed_remap(), dry_run=True) - result = WorkflowEndpointPhase(ctx).run(MigrationReport()) + result = WorkflowEndpointPhase(ctx).run(CloneReport()) assert result.skipped == 1 assert tgt.patch_calls == [] @@ -216,7 +216,7 @@ def test_no_workflows_in_remap_is_noop(): tgt = FakeClient() ctx = _ctx(src, tgt, remap=RemapTable()) - result = WorkflowEndpointPhase(ctx).run(MigrationReport()) + result = WorkflowEndpointPhase(ctx).run(CloneReport()) assert result.created == 0 assert tgt.patch_calls == [] diff --git a/tests/migration/test_workflow_phase.py b/tests/clone/test_workflow_phase.py similarity index 88% rename from tests/migration/test_workflow_phase.py rename to tests/clone/test_workflow_phase.py index 4583d75..638ed83 100644 --- a/tests/migration/test_workflow_phase.py +++ b/tests/clone/test_workflow_phase.py @@ -12,14 +12,14 @@ import pytest -from unstract.migration.context import ( - MigrationContext, - MigrationOptions, +from unstract.clone.context import ( + CloneContext, + CloneOptions, RemapTable, ) -from unstract.migration.exceptions import NameConflictError -from unstract.migration.phases.workflow import WorkflowPhase -from unstract.migration.report import MigrationReport +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.workflow import WorkflowPhase +from unstract.clone.report import CloneReport WORKFLOW_POST_SCHEMA = frozenset( @@ -77,10 +77,10 @@ def _src(id_, name, *, source_settings=None, destination_settings=None): def _ctx(source, target, *, remap=None, **opt_overrides): - return MigrationContext( + 
return CloneContext( source=source, target=target, - options=MigrationOptions(**opt_overrides), + options=CloneOptions(**opt_overrides), remap=remap or RemapTable(), ) @@ -103,7 +103,7 @@ def test_happy_path_creates_workflow_and_remaps_connector_uuids(): remap.record("connector", src_conn, tgt_conn) ctx = _ctx(src, tgt, remap=remap) - result = WorkflowPhase(ctx).run(MigrationReport()) + result = WorkflowPhase(ctx).run(CloneReport()) assert result.created == 1 assert result.failed == 0 @@ -125,7 +125,7 @@ def test_idempotent_rerun_adopts_existing_workflow(): ) ctx = _ctx(src, tgt, on_name_conflict="adopt") - result = WorkflowPhase(ctx).run(MigrationReport()) + result = WorkflowPhase(ctx).run(CloneReport()) assert result.adopted == 1 assert result.created == 0 @@ -138,7 +138,7 @@ def test_dry_run_creates_nothing(): tgt = FakeClient() ctx = _ctx(src, tgt, dry_run=True) - result = WorkflowPhase(ctx).run(MigrationReport()) + result = WorkflowPhase(ctx).run(CloneReport()) assert result.skipped == 1 assert tgt.posts == [] @@ -152,4 +152,4 @@ def test_abort_on_name_conflict_raises(): ctx = _ctx(src, tgt, on_name_conflict="abort") with pytest.raises(NameConflictError): - WorkflowPhase(ctx).run(MigrationReport()) + WorkflowPhase(ctx).run(CloneReport()) diff --git a/uv.lock b/uv.lock index 57e8a81..8710285 100644 --- a/uv.lock +++ b/uv.lock @@ -804,7 +804,7 @@ dependencies = [ ] [package.optional-dependencies] -migration = [ +clone = [ { name = "click" }, { name = "rich" }, ] @@ -841,12 +841,12 @@ test = [ [package.metadata] requires-dist = [ - { name = "click", marker = "extra == 'migration'", specifier = ">=8.1" }, + { name = "click", marker = "extra == 'clone'", specifier = ">=8.1" }, { name = "requests", specifier = ">=2.32.3" }, - { name = "rich", marker = "extra == 'migration'", specifier = ">=13.7" }, + { name = "rich", marker = "extra == 'clone'", specifier = ">=13.7" }, { name = "tenacity", specifier = ">=8.2.0" }, ] -provides-extras = ["migration"] +provides-extras 
= ["clone"] [package.metadata.requires-dev] dev = [ From da0724794df552d8e7a4c5ef7d8f4ddc53e90879 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Tue, 26 May 2026 11:31:00 +0530 Subject: [PATCH 21/25] docs(clone): top-of-README note on users + install instructions --- src/unstract/clone/README.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/unstract/clone/README.md b/src/unstract/clone/README.md index b952df4..8cce56f 100644 --- a/src/unstract/clone/README.md +++ b/src/unstract/clone/README.md @@ -1,5 +1,12 @@ # Cloning Organizations +> [!NOTE] +> **Users are not cloned.** Two reasons: +> - The same user may not need access in every environment. +> - The same user may hold different roles across environments. +> +> **Groups _will_ be cloned** (upcoming — not yet implemented). Once available, an admin can add the right users to each group per environment. + Clone an Unstract organization's configured resources into another organization (same deployment or different). Useful for environment promotion (DEV → QA → PROD) and for spinning up a fresh org from a known-good baseline. Cloned resources: adapters, connectors, custom tools, prompts, profiles, workflows, tool instances, workflow endpoints, tags, API deployments, pipelines, and Prompt Studio document files. The source org is left untouched. @@ -7,6 +14,16 @@ Cloned resources: adapters, connectors, custom tools, prompts, profiles, workflo > **Full documentation, behavior notes, CLI reference, and sample report:** > https://docs.unstract.com/unstract/unstract_platform/api_documentation/versions/v1-org-cloning/ +## Install + +From a clone of this repository: + +```bash +uv sync --all-extras +``` + +This pulls in the `clone` extra (`click`, `rich`) needed by the CLI. 
+ ## Quickstart ```bash From 836962cc818def5fd6a53703cd222e7ed6cfc658 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Tue, 26 May 2026 12:50:15 +0530 Subject: [PATCH 22/25] docs(clone): update public docs URL to /cloning-orgs/ --- src/unstract/clone/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/unstract/clone/README.md b/src/unstract/clone/README.md index 8cce56f..c155641 100644 --- a/src/unstract/clone/README.md +++ b/src/unstract/clone/README.md @@ -12,7 +12,7 @@ Clone an Unstract organization's configured resources into another organization Cloned resources: adapters, connectors, custom tools, prompts, profiles, workflows, tool instances, workflow endpoints, tags, API deployments, pipelines, and Prompt Studio document files. The source org is left untouched. > **Full documentation, behavior notes, CLI reference, and sample report:** -> https://docs.unstract.com/unstract/unstract_platform/api_documentation/versions/v1-org-cloning/ +> https://docs.unstract.com/unstract/unstract_platform/api_documentation/versions/cloning-orgs/ ## Install @@ -52,4 +52,4 @@ The Prompt Studio document corpus is the only resource type with bytes on disk. > [!WARNING] > Run clones during low-activity windows. Concurrent uploads to the source org during a clone can create duplicate file records on the target. -See the [public docs](https://docs.unstract.com/unstract/unstract_platform/api_documentation/versions/v1-org-cloning/) for the full flag list, behavioral notes, and the format of the end-of-run report. +See the [public docs](https://docs.unstract.com/unstract/unstract_platform/api_documentation/versions/cloning-orgs/) for the full flag list, behavioral notes, and the format of the end-of-run report. 
From 9391ce392fd710f826232bcc73e7ebcb304774b0 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Tue, 26 May 2026 17:44:52 +0530 Subject: [PATCH 23/25] fix(clone): address PR #15 review feedback (P1s + quick wins) P1 fixes: - tool_instance: dry-run no longer PATCHes adopted targets. - custom_tool: dry-run no longer republishes the registry on adopt. - custom_tool: registry remap-lookup failure now counts as failed. - custom_tool: phase-init failure returns `result`, not `None`. - files: malformed DM rows + unsupported-mime + oversize now bump `skipped` with an error entry so the run no longer reports green when items needed manual attention. - files: tighten `_lookup_tool_name` exceptions to PlatformAPIError and transport errors. - workflow_endpoint: drop the `connection_type or ""` coercion that could turn None into a blank string that DRF would reject; omit the key instead. - pipeline/api_deployment: do per-id GET before POST so list-only serializer fields don't get stripped from the create payload. P2/P3 quick wins: - PlatformClient: add `close()` + context manager; orchestrator closes both clients in a finally block. - pipeline/api_deployment: promote DEBUG to WARNING when the source key-list call fails (operator-facing). - cli: distinguish `--max-file-size 0` from unparseable; preserve 0. - files: drop the dead `docs/internal/...` reference. Tests: - New test_client.py + test_orchestrator.py + test_cli.py cover HTTP-layer + orchestrator paths that previously had no coverage. - Regression tests added to existing phase tests for each P1 fix. - 105 tests pass (was 75). 
--- src/unstract/clone/cli.py | 16 +- src/unstract/clone/client.py | 42 +++-- src/unstract/clone/orchestrator.py | 60 ++++--- src/unstract/clone/phases/api_deployment.py | 39 ++++- src/unstract/clone/phases/custom_tool.py | 55 +++--- src/unstract/clone/phases/files.py | 123 ++++++++++---- src/unstract/clone/phases/pipeline.py | 28 ++- src/unstract/clone/phases/tool_instance.py | 57 +++++-- .../clone/phases/workflow_endpoint.py | 36 ++-- tests/clone/test_api_deployment_phase.py | 11 +- tests/clone/test_cli.py | 129 ++++++++++++++ tests/clone/test_client.py | 145 ++++++++++++++++ tests/clone/test_custom_tool_phase.py | 35 +++- tests/clone/test_files_phase.py | 73 ++++++-- tests/clone/test_orchestrator.py | 159 ++++++++++++++++++ tests/clone/test_pipeline_phase.py | 59 ++++++- tests/clone/test_tool_instance_phase.py | 61 ++++++- tests/clone/test_workflow_endpoint_phase.py | 44 +++-- 18 files changed, 974 insertions(+), 198 deletions(-) create mode 100644 tests/clone/test_cli.py create mode 100644 tests/clone/test_client.py create mode 100644 tests/clone/test_orchestrator.py diff --git a/src/unstract/clone/cli.py b/src/unstract/clone/cli.py index 5ce15b8..6b7a0b4 100644 --- a/src/unstract/clone/cli.py +++ b/src/unstract/clone/cli.py @@ -70,7 +70,9 @@ def cli() -> None: @cli.command("clone") @click.option("--source-url", required=True, help="Base URL of the source deployment") -@click.option("--source-org", required=True, help="Source organization_id (slug in the URL path)") +@click.option( + "--source-org", required=True, help="Source organization_id (slug in the URL path)" +) @click.option( "--source-key", envvar="UNSTRACT_SRC_PLATFORM_KEY", @@ -78,14 +80,18 @@ def cli() -> None: help="Source admin's Platform API key (or env UNSTRACT_SRC_PLATFORM_KEY)", ) @click.option("--target-url", required=True, help="Base URL of the target deployment") -@click.option("--target-org", required=True, help="Target organization_id (slug in the URL path)") +@click.option( + 
"--target-org", required=True, help="Target organization_id (slug in the URL path)" +) @click.option( "--target-key", envvar="UNSTRACT_TGT_PLATFORM_KEY", required=True, help="Target admin's Platform API key (or env UNSTRACT_TGT_PLATFORM_KEY)", ) -@click.option("--dry-run", is_flag=True, help="Plan only — do not POST anything to target") +@click.option( + "--dry-run", is_flag=True, help="Plan only — do not POST anything to target" +) @click.option( "--include", default=None, @@ -161,7 +167,9 @@ def clone_cmd( on_name_conflict=on_name_conflict, verbose=verbose, file_strategy=effective_strategy, - max_file_size=cap_bytes or DEFAULT_MAX_FILE_SIZE, + # Distinguish "user said 0" (force every file to manual list) from + # an unparseable size — `_parse_size` raises in the latter case. + max_file_size=cap_bytes if cap_bytes is not None else DEFAULT_MAX_FILE_SIZE, ) source = OrgEndpoint( diff --git a/src/unstract/clone/client.py b/src/unstract/clone/client.py index 097394a..ad873da 100644 --- a/src/unstract/clone/client.py +++ b/src/unstract/clone/client.py @@ -27,7 +27,9 @@ class PlatformClient: """HTTP client scoped to a single org via its Platform API key.""" - def __init__(self, endpoint: OrgEndpoint, timeout: int = DEFAULT_TIMEOUT, verify: bool = True): + def __init__( + self, endpoint: OrgEndpoint, timeout: int = DEFAULT_TIMEOUT, verify: bool = True + ): self.endpoint = endpoint self.timeout = timeout self.verify = verify @@ -42,6 +44,16 @@ def __init__(self, endpoint: OrgEndpoint, timeout: int = DEFAULT_TIMEOUT, verify # Backend serializer is the single source of truth; we read it once. 
self._post_schema_cache: dict[str, frozenset[str]] = {} + def close(self) -> None: + """Release the underlying HTTP connection pool.""" + self._session.close() + + def __enter__(self) -> "PlatformClient": + return self + + def __exit__(self, *exc: Any) -> None: + self.close() + def _url(self, path: str) -> str: base = self.endpoint.base_url.rstrip("/") api_prefix = self.endpoint.api_path_prefix.strip("/") @@ -178,9 +190,7 @@ def get_custom_tool(self, tool_id: str) -> dict[str, Any]: """ return self._request("GET", f"prompt-studio/{tool_id}/") - def update_custom_tool( - self, tool_id: str, body: dict[str, Any] - ) -> dict[str, Any]: + def update_custom_tool(self, tool_id: str, body: dict[str, Any]) -> dict[str, Any]: """PATCH a prompt-studio project. Used to set ``output`` (the default doc id) after the files phase populates DM rows.""" return self._request("PATCH", f"prompt-studio/{tool_id}/", json=body) @@ -192,9 +202,7 @@ def list_profiles(self, tool_id: str) -> list[dict[str, Any]]: default profile's adapter UUIDs so they can be remapped to target adapter ids for ``import_project``. """ - result = self._request( - "GET", f"prompt-studio/prompt-studio-profile/{tool_id}/" - ) + result = self._request("GET", f"prompt-studio/prompt-studio-profile/{tool_id}/") return result if isinstance(result, list) else result.get("results", []) def export_project(self, tool_id: str) -> dict[str, Any]: @@ -226,9 +234,7 @@ def import_project( required to wire the profile; otherwise backend falls back to a profile without adapters and flags ``needs_adapter_config``. 
""" - tool_name = ( - export_data.get("tool_metadata", {}).get("tool_name") or "export" - ) + tool_name = export_data.get("tool_metadata", {}).get("tool_name") or "export" content = json_lib.dumps(export_data).encode() files = {"file": (f"{tool_name}.json", content, "application/json")} data: dict[str, Any] = {} @@ -281,9 +287,7 @@ def list_prompt_documents(self, tool_id: str) -> list[dict[str, Any]]: ) return result if isinstance(result, list) else result.get("results", []) - def download_prompt_file( - self, tool_id: str, document_id: str - ) -> dict[str, Any]: + def download_prompt_file(self, tool_id: str, document_id: str) -> dict[str, Any]: """GET a Prompt Studio document by tool + DM row id. ``fetch_contents_ide`` resolves the filename internally from the @@ -313,9 +317,7 @@ def upload_prompt_file( an IntegrityError → 500 on re-runs. """ files = {"file": (file_name, data, mime_type)} - return self._request( - "POST", f"prompt-studio/file/{tool_id}", files=files - ) + return self._request("POST", f"prompt-studio/file/{tool_id}", files=files) def export_custom_tool(self, tool_id: str, *, force: bool = True) -> Any: """Republish ``PromptStudioRegistry`` from the tool's current state. 
@@ -415,9 +417,7 @@ def list_workflow_endpoints( def update_workflow_endpoint( self, endpoint_id: str, payload: dict[str, Any] ) -> dict[str, Any]: - return self._request( - "PATCH", f"workflow/endpoint/{endpoint_id}/", json=payload - ) + return self._request("PATCH", f"workflow/endpoint/{endpoint_id}/", json=payload) # ----- pipelines (ETL / TASK) ----- @@ -478,9 +478,7 @@ def create_api_deployment(self, payload: dict[str, Any]) -> dict[str, Any]: def update_api_deployment( self, deployment_id: str, payload: dict[str, Any] ) -> dict[str, Any]: - return self._request( - "PATCH", f"api/deployment/{deployment_id}/", json=payload - ) + return self._request("PATCH", f"api/deployment/{deployment_id}/", json=payload) # ----- API keys (per pipeline / deployment) ----- diff --git a/src/unstract/clone/orchestrator.py b/src/unstract/clone/orchestrator.py index 2a21d54..552c0b0 100644 --- a/src/unstract/clone/orchestrator.py +++ b/src/unstract/clone/orchestrator.py @@ -29,7 +29,7 @@ WorkflowPhase, ) from unstract.clone.phases.base import Phase -from unstract.clone.report import Endpoint, CloneReport +from unstract.clone.report import CloneReport, Endpoint logger = logging.getLogger(__name__) @@ -64,29 +64,39 @@ def clone( setup errors or ``on_name_conflict='abort'`` collisions. 
""" opts = options or CloneOptions() - ctx = CloneContext( - source=PlatformClient(source), - target=PlatformClient(target), - options=opts, - ) - report = CloneReport( - source=Endpoint(base_url=source.base_url, organization_id=source.organization_id), - target=Endpoint(base_url=target.base_url, organization_id=target.organization_id), - ) + src_client = PlatformClient(source) + tgt_client = PlatformClient(target) + try: + ctx = CloneContext( + source=src_client, + target=tgt_client, + options=opts, + ) + report = CloneReport( + source=Endpoint( + base_url=source.base_url, organization_id=source.organization_id + ), + target=Endpoint( + base_url=target.base_url, organization_id=target.organization_id + ), + ) - for name, phase_cls in PHASES: - if not opts.includes(name): - report.skipped_phases.append(name) - logger.info("Phase '%s' skipped (excluded)", name) - continue - logger.info("=== Phase: %s ===", name) - try: - phase_cls(ctx).run(report) - except CloneError as e: - report.aborted = True - report.abort_reason = str(e) - logger.error("Phase '%s' aborted: %s", name, e) - break + for name, phase_cls in PHASES: + if not opts.includes(name): + report.skipped_phases.append(name) + logger.info("Phase '%s' skipped (excluded)", name) + continue + logger.info("=== Phase: %s ===", name) + try: + phase_cls(ctx).run(report) + except CloneError as e: + report.aborted = True + report.abort_reason = str(e) + logger.error("Phase '%s' aborted: %s", name, e) + break - report.remap_snapshot = ctx.remap.snapshot() - return report + report.remap_snapshot = ctx.remap.snapshot() + return report + finally: + src_client.close() + tgt_client.close() diff --git a/src/unstract/clone/phases/api_deployment.py b/src/unstract/clone/phases/api_deployment.py index 00c85b7..bbd20ed 100644 --- a/src/unstract/clone/phases/api_deployment.py +++ b/src/unstract/clone/phases/api_deployment.py @@ -70,7 +70,8 @@ def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: if not tgt_wf_id: 
logger.warning( "no workflow remap for api_deployment '%s' (src workflow %s) — skipping", - api_name, src_wf_id, + api_name, + src_wf_id, ) result.skipped += 1 return @@ -94,7 +95,9 @@ def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: result.adopted += 1 logger.info( "adopted api_deployment '%s' src=%s -> tgt=%s", - api_name, src_id, tgt["id"], + api_name, + src_id, + tgt["id"], ) elif self.ctx.options.dry_run: result.skipped += 1 @@ -103,22 +106,32 @@ def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: ) return else: - remapped = remap_uuids(src, self.ctx.remap) + try: + # list serializer can strip fields the create serializer expects. + full_src = self.ctx.source.get_api_deployment(src_id) + except Exception as e: + logger.exception( + "Failed to GET source api_deployment %s: %s", api_name, e + ) + result.failed += 1 + result.errors.append(f"GET src api_deployment {api_name}: {e}") + return + remapped = remap_uuids(full_src, self.ctx.remap) payload = build_post_payload(remapped, self._writable) payload["workflow"] = tgt_wf_id try: tgt = self.ctx.target.create_api_deployment(payload) except Exception as e: - logger.exception( - "Failed to create api_deployment %s: %s", api_name, e - ) + logger.exception("Failed to create api_deployment %s: %s", api_name, e) result.failed += 1 result.errors.append(f"create {api_name}: {e}") return result.created += 1 logger.info( "created api_deployment '%s' src=%s -> tgt=%s", - api_name, src_id, tgt["id"], + api_name, + src_id, + tgt["id"], ) self._warn_if_extra_source_keys(src_id, api_name) @@ -128,7 +141,14 @@ def _warn_if_extra_source_keys(self, src_deployment_id: str, name: str) -> None: try: keys = self.ctx.source.list_api_deployment_keys(src_deployment_id) except Exception as e: - logger.debug("Could not list source keys for api_deployment %s: %s", name, e) + # WARNING (not DEBUG) — the operator needs to know we couldn't + # check whether they have additional keys to recreate 
manually. + logger.warning( + "Could not list source keys for api_deployment %s " + "(extra-key check skipped; re-verify in source UI): %s", + name, + e, + ) return active = [k for k in keys if k.get("is_active")] if len(active) > 1: @@ -136,5 +156,6 @@ def _warn_if_extra_source_keys(self, src_deployment_id: str, name: str) -> None: "source api_deployment '%s' had %d active API keys; " "target has only the auto-provisioned default — " "re-create the rest manually if your clients depend on them", - name, len(active), + name, + len(active), ) diff --git a/src/unstract/clone/phases/custom_tool.py b/src/unstract/clone/phases/custom_tool.py index 8cb03ca..f633e74 100644 --- a/src/unstract/clone/phases/custom_tool.py +++ b/src/unstract/clone/phases/custom_tool.py @@ -80,7 +80,7 @@ def run(self, report: CloneReport) -> PhaseResult: logger.exception("Failed to list target tools: %s", e) result.failed += 1 result.errors.append(f"list target tools: {e}") - return + return result for summary in src_tools: self._clone_one(summary, target_tools, result) return result @@ -102,12 +102,12 @@ def _clone_one( result.errors.append(f"export src tool {tool_name}: {e}") return - match = next( - (t for t in target_tools if t["tool_name"] == tool_name), None - ) + match = next((t for t in target_tools if t["tool_name"] == tool_name), None) if match is not None: - tgt_tool_id = self._adopt(match, export_data, result, tool_name, src_tool_id) + tgt_tool_id = self._adopt( + match, export_data, result, tool_name, src_tool_id + ) else: tgt_tool_id = self._create_fresh( export_data, src_tool_id, tool_name, result @@ -116,29 +116,27 @@ def _clone_one( # with the same name (uncommon but legal) adopts this new # row instead of trying to re-create it. 
if tgt_tool_id is not None: - target_tools.append( - {"tool_id": tgt_tool_id, "tool_name": tool_name} - ) + target_tools.append({"tool_id": tgt_tool_id, "tool_name": tool_name}) if tgt_tool_id is None: return self.ctx.remap.record("custom_tool", src_tool_id, tgt_tool_id) + if self.ctx.options.dry_run: + return + try: self.ctx.target.export_custom_tool(tgt_tool_id) - logger.info("republished registry for tool '%s' tgt=%s", tool_name, tgt_tool_id) + logger.info( + "republished registry for tool '%s' tgt=%s", tool_name, tgt_tool_id + ) except Exception as e: logger.exception("Registry republish failed for tool %s: %s", tool_name, e) result.failed += 1 result.errors.append(f"export {tool_name}: {e}") return - # Record registry remap so ToolInstancePhase can rewrite - # ToolInstance.tool_id (which stores a registry UUID as CharField). - # Source registry exists only if the operator already published - # the tool there; unpublished source tools simply produce no - # ToolInstance rows for downstream to remap. 
try: src_regs = self.ctx.source.list_registries(custom_tool=src_tool_id) tgt_regs = self.ctx.target.list_registries(custom_tool=tgt_tool_id) @@ -146,8 +144,11 @@ def _clone_one( logger.warning( "registry remap lookup failed for tool '%s' " "(downstream ToolInstance clone may skip): %s", - tool_name, e, + tool_name, + e, ) + result.failed += 1 + result.errors.append(f"registry remap lookup {tool_name}: {e}") return if src_regs and tgt_regs: @@ -175,7 +176,9 @@ def _adopt( result.skipped += 1 logger.info( "[dry-run] would sync prompts into adopted tool '%s' src=%s -> tgt=%s", - tool_name, src_tool_id, tgt_tool_id, + tool_name, + src_tool_id, + tgt_tool_id, ) return tgt_tool_id @@ -190,7 +193,9 @@ def _adopt( result.adopted += 1 logger.info( "adopted tool '%s' src=%s -> tgt=%s (prompts re-synced)", - tool_name, src_tool_id, tgt_tool_id, + tool_name, + src_tool_id, + tgt_tool_id, ) return tgt_tool_id @@ -228,7 +233,10 @@ def _create_fresh( result.created += 1 logger.info( "created tool '%s' src=%s -> tgt=%s (needs_adapter_config=%s)", - tool_name, src_tool_id, tgt_tool_id, tgt.get("needs_adapter_config"), + tool_name, + src_tool_id, + tgt_tool_id, + tgt.get("needs_adapter_config"), ) return tgt_tool_id @@ -268,7 +276,8 @@ def _resolve_target_adapter_ids( if not adapter_name: logger.warning( "source default profile for tool '%s' missing adapter '%s'", - tool_name, src_field, + tool_name, + src_field, ) return None try: @@ -276,13 +285,17 @@ def _resolve_target_adapter_ids( except Exception as e: logger.exception( "list_adapters lookup failed for %s on tool '%s': %s", - adapter_name, tool_name, e, + adapter_name, + tool_name, + e, ) return None if not matches: logger.warning( "no target adapter named '%s' for field %s on tool '%s'", - adapter_name, src_field, tool_name, + adapter_name, + src_field, + tool_name, ) return None resolved[form_field] = matches[0]["id"] diff --git a/src/unstract/clone/phases/files.py b/src/unstract/clone/phases/files.py index da403f2..d28ce6f 
100644 --- a/src/unstract/clone/phases/files.py +++ b/src/unstract/clone/phases/files.py @@ -20,8 +20,7 @@ Concurrency is 1 per phase by design — the Platform API endpoint holds a cloud worker for the whole upload, and uploads are not chunked on the BE -helper today. See ``docs/internal/files-clone-plan.md`` for the -sizing rationale. +helper today. """ from __future__ import annotations @@ -63,7 +62,9 @@ def run(self, report: CloneReport) -> PhaseResult: strategy = self.ctx.options.file_strategy logger.info( "files phase: strategy=%s tools=%d cap=%d bytes", - strategy, len(tool_remap), self.ctx.options.max_file_size, + strategy, + len(tool_remap), + self.ctx.options.max_file_size, ) for src_tool_id, tgt_tool_id in tool_remap.items(): @@ -73,16 +74,17 @@ def run(self, report: CloneReport) -> PhaseResult: except Exception as e: logger.exception( "files: failed to list source DM rows for tool %s: %s", - tool_name, e, + tool_name, + e, ) result.failed += 1 - result.errors.append( - f"list source docs {tool_name}: {e}" - ) + result.errors.append(f"list source docs {tool_name}: {e}") continue if strategy == "skip": - self._emit_skip(src_docs, src_tool_id, tgt_tool_id, tool_name, report, result) + self._emit_skip( + src_docs, src_tool_id, tgt_tool_id, tool_name, report, result + ) continue self._clone_tool( @@ -105,7 +107,8 @@ def _clone_tool( except Exception as e: logger.exception( "files: failed to list target DM rows for tool %s: %s", - tool_name, e, + tool_name, + e, ) result.failed += 1 result.errors.append(f"list target docs {tool_name}: {e}") @@ -116,19 +119,30 @@ def _clone_tool( file_name = doc.get("document_name") src_document_id = doc.get("document_id") if not file_name or not src_document_id: + result.skipped += 1 + result.errors.append( + f"malformed source DM row on tool={tool_name}: {doc!r}" + ) + logger.warning( + "files: skipping malformed source DM row on tool=%s: %r", + tool_name, + doc, + ) continue if file_name in target_names: result.skipped += 1 
logger.debug( "files: already present on target tool=%s file=%s", - tool_name, file_name, + tool_name, + file_name, ) continue if self.ctx.options.dry_run: result.skipped += 1 logger.info( "[dry-run] files: would clone tool=%s file=%s", - tool_name, file_name, + tool_name, + file_name, ) continue self._clone_one_file( @@ -142,9 +156,7 @@ def _clone_tool( ) if not self.ctx.options.dry_run: - self._ensure_default_doc( - src_tool_id, tgt_tool_id, tool_name, src_docs - ) + self._ensure_default_doc(src_tool_id, tgt_tool_id, tool_name, src_docs) def _clone_one_file( self, @@ -166,7 +178,9 @@ def _clone_one_file( except Exception as e: logger.exception( "files: download failed tool=%s file=%s: %s", - tool_name, file_name, e, + tool_name, + file_name, + e, ) result.failed += 1 report.failed_files.append( @@ -184,8 +198,11 @@ def _clone_one_file( if raw is None: logger.warning( "files: unsupported mime tool=%s file=%s mime=%s", - tool_name, file_name, mime, + tool_name, + file_name, + mime, ) + result.skipped += 1 report.unsupported_files.append( { "tool_id": tgt_tool_id, @@ -197,6 +214,7 @@ def _clone_one_file( return if len(raw) > self.ctx.options.max_file_size: + result.skipped += 1 report.oversize_files.append( { "tool_id": tgt_tool_id, @@ -208,7 +226,10 @@ def _clone_one_file( ) logger.info( "files: oversize tool=%s file=%s size=%d cap=%d", - tool_name, file_name, len(raw), self.ctx.options.max_file_size, + tool_name, + file_name, + len(raw), + self.ctx.options.max_file_size, ) return @@ -222,7 +243,9 @@ def _clone_one_file( except Exception as e: logger.exception( "files: upload failed tool=%s file=%s: %s", - tool_name, file_name, e, + tool_name, + file_name, + e, ) result.failed += 1 report.failed_files.append( @@ -247,7 +270,9 @@ def _clone_one_file( ) logger.info( "files: uploaded tool=%s file=%s size=%d", - tool_name, file_name, len(raw), + tool_name, + file_name, + len(raw), ) def _emit_skip( @@ -275,7 +300,8 @@ def _emit_skip( result.skipped += 1 logger.info( 
"files: skip mode emitted %d filenames for tool=%s", - len(src_docs), tool_name, + len(src_docs), + tool_name, ) def _decode_payload( @@ -318,7 +344,8 @@ def _ensure_default_doc( except Exception as e: logger.warning( "files: skipping default-doc set for tool=%s — fetch tgt failed: %s", - tool_name, e, + tool_name, + e, ) return @@ -334,7 +361,8 @@ def _ensure_default_doc( except Exception as e: logger.warning( "files: skipping default-doc set for tool=%s — list tgt docs failed: %s", - tool_name, e, + tool_name, + e, ) return if not tgt_docs: @@ -352,9 +380,7 @@ def _ensure_default_doc( "files: set default doc tool=%s doc_id=%s", tool_name, chosen_id ) except Exception as e: - logger.warning( - "files: PATCH default doc failed tool=%s: %s", tool_name, e - ) + logger.warning("files: PATCH default doc failed tool=%s: %s", tool_name, e) def _pick_default_doc_id( self, @@ -373,20 +399,27 @@ def _pick_default_doc_id( logger.debug( "files: source CustomTool fetch failed for tool=%s (%s); " "falling back to first target doc", - tool_name, e, + tool_name, + e, ) src_output = None if src_output: src_name = next( - (d.get("document_name") for d in src_docs - if d.get("document_id") == src_output), + ( + d.get("document_name") + for d in src_docs + if d.get("document_id") == src_output + ), None, ) if src_name: matched = next( - (d.get("document_id") for d in tgt_docs - if d.get("document_name") == src_name), + ( + d.get("document_id") + for d in tgt_docs + if d.get("document_name") == src_name + ), None, ) if matched: @@ -395,11 +428,23 @@ def _pick_default_doc_id( return tgt_docs[0].get("document_id") def _lookup_tool_name(self, tgt_tool_id: str) -> str | None: - # CustomToolPhase doesn't record names; fetch lazily for log clarity. - # One call per tool is cheap relative to the per-file traffic. + # Cosmetic helper for logs only — never let a transport hiccup here + # mask a downstream "tool was deleted" or auth failure. 
try: tools = self.ctx.target.list_custom_tools() - except Exception: + except PlatformAPIError as e: + logger.warning( + "files: list_custom_tools failed during name lookup (%s); " + "log lines will fall back to tool ids", + e, + ) + return None + except (requests.ConnectionError, requests.Timeout) as e: + logger.warning( + "files: transport error during tool-name lookup (%s); " + "log lines will fall back to tool ids", + e, + ) return None for t in tools: if t.get("tool_id") == tgt_tool_id: @@ -418,7 +463,11 @@ def _with_retry(self, fn: Any, *, op: str) -> Any: sleep = _RETRY_BACKOFF_BASE_SECONDS * (2 ** (attempt - 1)) logger.warning( "files: retry %d/%d for %s after %d: sleeping %.1fs", - attempt, _MAX_RETRIES, op, e.status_code, sleep, + attempt, + _MAX_RETRIES, + op, + e.status_code, + sleep, ) time.sleep(sleep) except (requests.ConnectionError, requests.Timeout) as e: @@ -428,7 +477,11 @@ def _with_retry(self, fn: Any, *, op: str) -> Any: sleep = _RETRY_BACKOFF_BASE_SECONDS * (2 ** (attempt - 1)) logger.warning( "files: retry %d/%d for %s after %s: sleeping %.1fs", - attempt, _MAX_RETRIES, op, type(e).__name__, sleep, + attempt, + _MAX_RETRIES, + op, + type(e).__name__, + sleep, ) time.sleep(sleep) assert last_exc is not None diff --git a/src/unstract/clone/phases/pipeline.py b/src/unstract/clone/phases/pipeline.py index 15b1835..8142128 100644 --- a/src/unstract/clone/phases/pipeline.py +++ b/src/unstract/clone/phases/pipeline.py @@ -55,7 +55,8 @@ def run(self, report: CloneReport) -> PhaseResult: if skipped_types: logger.info( "Found %d source pipeline(s); skipping %d of unsupported type (DEFAULT/APP)", - len(src_pipelines), skipped_types, + len(src_pipelines), + skipped_types, ) else: logger.info("Found %d source pipeline(s)", len(src_pipelines)) @@ -78,7 +79,8 @@ def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: if not tgt_wf_id: logger.warning( "no workflow remap for pipeline '%s' (src workflow %s) — skipping", - name, src_wf_id, + 
name, + src_wf_id, ) result.skipped += 1 return @@ -106,7 +108,15 @@ def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: logger.info("[dry-run] would create pipeline '%s' src=%s", name, src_id) return else: - remapped = remap_uuids(src, self.ctx.remap) + try: + # list serializer can strip fields the create serializer expects. + full_src = self.ctx.source.get_pipeline(src_id) + except Exception as e: + logger.exception("Failed to GET source pipeline %s: %s", name, e) + result.failed += 1 + result.errors.append(f"GET src pipeline {name}: {e}") + return + remapped = remap_uuids(full_src, self.ctx.remap) payload = build_post_payload(remapped, self._writable) payload["workflow"] = tgt_wf_id try: @@ -128,7 +138,14 @@ def _warn_if_extra_source_keys(self, src_pipeline_id: str, name: str) -> None: try: keys = self.ctx.source.list_pipeline_keys(src_pipeline_id) except Exception as e: - logger.debug("Could not list source keys for pipeline %s: %s", name, e) + # WARNING (not DEBUG) — the operator needs to know we couldn't + # check whether they have additional keys to recreate manually. + logger.warning( + "Could not list source keys for pipeline %s " + "(extra-key check skipped; re-verify in source UI): %s", + name, + e, + ) return active = [k for k in keys if k.get("is_active")] if len(active) > 1: @@ -136,5 +153,6 @@ def _warn_if_extra_source_keys(self, src_pipeline_id: str, name: str) -> None: "source pipeline '%s' had %d active API keys; " "target has only the auto-provisioned default — " "re-create the rest manually if your clients depend on them", - name, len(active), + name, + len(active), ) diff --git a/src/unstract/clone/phases/tool_instance.py b/src/unstract/clone/phases/tool_instance.py index ac41a9b..67e5ca5 100644 --- a/src/unstract/clone/phases/tool_instance.py +++ b/src/unstract/clone/phases/tool_instance.py @@ -39,12 +39,8 @@ "[NEEDS UPDATE]", ) -# Identity fields that point at backend rows by primary key. 
They were -# populated server-side at create time on source and must NOT be carried -# across orgs — the target's create_tool_instance has already set the -# correct target values. Leaking source ids here makes the structure -# tool fetch the source registry at runtime (platform-service looks up -# registries by id only, no org scope) and load the wrong adapters. +# Fields tied to the source row's own ids — never valid on the target. +# Always rewrite these with target values before PATCHing. _SOURCE_IDENTITY_FIELDS: tuple[str, ...] = ( "prompt_registry_id", "tool_instance_id", @@ -86,7 +82,9 @@ def _clone_workflow_tools( try: src_instances = self.ctx.source.list_tool_instances(workflow_id=src_wf_id) except Exception as e: - logger.exception("Failed to list source tool_instances for wf %s: %s", src_wf_id, e) + logger.exception( + "Failed to list source tool_instances for wf %s: %s", src_wf_id, e + ) result.failed += 1 result.errors.append(f"list src tool_instances {src_wf_id}: {e}") return @@ -97,7 +95,8 @@ def _clone_workflow_tools( # Backend enforces ≤1; warn loudly if invariant breaks on source. 
logger.warning( "source workflow %s has %d tool_instances (expected ≤1) — migrating first only", - src_wf_id, len(src_instances), + src_wf_id, + len(src_instances), ) src_ti = src_instances[0] @@ -109,7 +108,8 @@ def _clone_workflow_tools( logger.warning( "skipping tool_instance %s — no registry remap for tool_id %s " "(custom tool likely unpublished on source)", - src_ti_id, src_tool_id, + src_ti_id, + src_tool_id, ) result.skipped += 1 return @@ -117,24 +117,40 @@ def _clone_workflow_tools( try: existing = self.ctx.target.list_tool_instances(workflow_id=tgt_wf_id) except Exception as e: - logger.exception("Failed to list target tool_instances for wf %s: %s", tgt_wf_id, e) + logger.exception( + "Failed to list target tool_instances for wf %s: %s", tgt_wf_id, e + ) result.failed += 1 result.errors.append(f"list tgt tool_instances {tgt_wf_id}: {e}") return if existing: tgt_ti = existing[0] + if self.ctx.options.dry_run: + result.skipped += 1 + logger.info( + "[dry-run] would re-PATCH metadata on adopted tool_instance " + "src=%s -> tgt=%s (workflow %s)", + src_ti_id, + tgt_ti["id"], + tgt_wf_id, + ) + self.ctx.remap.record("tool_instance", src_ti_id, tgt_ti["id"]) + return result.adopted += 1 logger.info( "adopted tool_instance src=%s -> tgt=%s (workflow %s)", - src_ti_id, tgt_ti["id"], tgt_wf_id, + src_ti_id, + tgt_ti["id"], + tgt_wf_id, ) elif self.ctx.options.dry_run: result.skipped += 1 logger.info( "[dry-run] would create tool_instance for tgt workflow %s " "(src tool_instance %s)", - tgt_wf_id, src_ti_id, + tgt_wf_id, + src_ti_id, ) return else: @@ -152,7 +168,9 @@ def _clone_workflow_tools( result.created += 1 logger.info( "created tool_instance src=%s -> tgt=%s (workflow %s)", - src_ti_id, tgt_ti["id"], tgt_wf_id, + src_ti_id, + tgt_ti["id"], + tgt_wf_id, ) # PATCH the metadata regardless of created/adopted — keeps tool config @@ -164,13 +182,22 @@ def _clone_workflow_tools( "skipping metadata PATCH for tool_instance src=%s tgt=%s — " "source metadata 
carries broken adapter refs %s; " "row exists with backend defaults, re-bind in UI", - src_ti_id, tgt_ti["id"], broken, + src_ti_id, + tgt_ti["id"], + broken, ) result.errors.append( f"stale adapter refs on src tool_instance {src_ti_id}: {broken}" ) else: - patch_metadata = _strip_source_identity(src_metadata) + # PATCH overwrites the whole metadata dict, so we must include + # the target's own identity fields or the runtime sees them + # as empty. tool_id IS the prompt_registry_id. + patch_metadata = { + **_strip_source_identity(src_metadata), + "prompt_registry_id": tgt_tool_id, + "tool_instance_id": tgt_ti["id"], + } try: self.ctx.target.update_tool_instance_metadata( tgt_ti["id"], patch_metadata diff --git a/src/unstract/clone/phases/workflow_endpoint.py b/src/unstract/clone/phases/workflow_endpoint.py index a54f34a..8c235d0 100644 --- a/src/unstract/clone/phases/workflow_endpoint.py +++ b/src/unstract/clone/phases/workflow_endpoint.py @@ -46,7 +46,9 @@ def run(self, report: CloneReport) -> PhaseResult: result = report.get_phase(self.name) workflow_remap = self.ctx.remap.snapshot().get("workflow", {}) if not workflow_remap: - logger.info("No workflows in remap; nothing to do for workflow_endpoint phase") + logger.info( + "No workflows in remap; nothing to do for workflow_endpoint phase" + ) return result for src_wf_id, tgt_wf_id in workflow_remap.items(): @@ -90,12 +92,11 @@ def _clone_workflow_endpoints( # workflow create flow failed earlier — surface loudly. 
logger.warning( "target workflow %s missing %s endpoint — skipping", - tgt_wf_id, etype, + tgt_wf_id, + etype, ) result.failed += 1 - result.errors.append( - f"missing tgt {etype} endpoint for wf {tgt_wf_id}" - ) + result.errors.append(f"missing tgt {etype} endpoint for wf {tgt_wf_id}") continue self._patch_endpoint(src_ep, tgt_ep, result) @@ -111,7 +112,9 @@ def _patch_endpoint( result.skipped += 1 logger.info( "[dry-run] would PATCH %s endpoint src=%s -> tgt=%s", - etype, src_ep_id, tgt_ep_id, + etype, + src_ep_id, + tgt_ep_id, ) return @@ -127,7 +130,10 @@ def _patch_endpoint( logger.warning( "skipping %s endpoint src=%s tgt=%s — source connector %s " "has no target remap; would silently unset connector", - etype, src_ep_id, tgt_ep_id, src_conn_id, + etype, + src_ep_id, + tgt_ep_id, + src_conn_id, ) result.skipped += 1 result.errors.append( @@ -136,11 +142,18 @@ def _patch_endpoint( ) return + # connection_type is a required enum on the backend; pass source's value + # through verbatim so backend validation surfaces the real problem. When + # source has None, omit the key rather than papering over with "". 
payload: dict[str, Any] = { - "connection_type": src_ep.get("connection_type") or "", - "configuration": remap_uuids(src_ep.get("configuration") or {}, self.ctx.remap), + "configuration": remap_uuids( + src_ep.get("configuration") or {}, self.ctx.remap + ), "connector_instance_id": tgt_conn_id, } + src_connection_type = src_ep.get("connection_type") + if src_connection_type is not None: + payload["connection_type"] = src_connection_type try: self.ctx.target.update_workflow_endpoint(tgt_ep_id, payload) @@ -155,6 +168,9 @@ def _patch_endpoint( result.created += 1 logger.info( "patched %s endpoint src=%s -> tgt=%s (connector %s)", - etype, src_ep_id, tgt_ep_id, tgt_conn_id, + etype, + src_ep_id, + tgt_ep_id, + tgt_conn_id, ) self.ctx.remap.record("workflow_endpoint", src_ep_id, tgt_ep_id) diff --git a/tests/clone/test_api_deployment_phase.py b/tests/clone/test_api_deployment_phase.py index df84c91..dc25d7a 100644 --- a/tests/clone/test_api_deployment_phase.py +++ b/tests/clone/test_api_deployment_phase.py @@ -24,7 +24,6 @@ from unstract.clone.phases.api_deployment import APIDeploymentPhase from unstract.clone.report import CloneReport - API_DEPLOYMENT_POST_SCHEMA = frozenset( { "display_name", @@ -54,6 +53,12 @@ def list_api_deployments(self, *, api_name: str | None = None): result = [d for d in result if d["api_name"] == api_name] return list(result) + def get_api_deployment(self, deployment_id: str) -> dict: + for d in self.deployments: + if d["id"] == deployment_id: + return dict(d) + raise KeyError(deployment_id) + def create_api_deployment(self, payload: dict) -> dict: new = dict(payload) new["id"] = f"tgt-dep-{self._next:04d}" @@ -111,9 +116,7 @@ def test_happy_path_creates_deployment_with_remapped_workflow(): def test_adopts_existing_deployment_by_api_name(): src = FakeClient([_src_deployment("src-dep-1", "invoices_api", "wf-src-1")]) - tgt = FakeClient( - [{"id": "tgt-existing", "api_name": "invoices_api"}] - ) + tgt = FakeClient([{"id": "tgt-existing", 
"api_name": "invoices_api"}]) remap = RemapTable() remap.record("workflow", "wf-src-1", "wf-tgt-1") ctx = _ctx(src, tgt, remap=remap) diff --git a/tests/clone/test_cli.py b/tests/clone/test_cli.py new file mode 100644 index 0000000..e4ac623 --- /dev/null +++ b/tests/clone/test_cli.py @@ -0,0 +1,129 @@ +"""Tests for the click CLI wiring in ``unstract.clone.cli``. + +Coverage: +- ``_parse_size`` accepts bare integers, K/M/G suffixes, decimals. +- ``--max-file-size 0`` propagates as 0 (force every file to manual list), + not the default cap — distinguished from the unparseable case. +""" + +from __future__ import annotations + +import pytest +from click.testing import CliRunner + +from unstract.clone.cli import _parse_size, cli +from unstract.clone.context import DEFAULT_MAX_FILE_SIZE, CloneOptions +from unstract.clone.report import CloneReport, Endpoint + + +def test_parse_size_bare_int_is_bytes(): + assert _parse_size("25") == 25 + + +def test_parse_size_accepts_kb_mb_gb_units(): + assert _parse_size("25MB") == 25 * 1024 * 1024 + assert _parse_size("1.5GB") == int(1.5 * 1024 * 1024 * 1024) + assert _parse_size("512K") == 512 * 1024 + + +def test_parse_size_zero_returns_zero(): + # Regression for `cap_bytes or DEFAULT` — must not coerce 0 to the + # default. CLI flag --max-file-size 0 means "every file goes to the + # oversize/manual-upload list". 
+ assert _parse_size("0") == 0 + + +def test_parse_size_unknown_unit_raises(): + import click + + with pytest.raises(click.BadParameter): + _parse_size("10XB") + + +def test_parse_size_unparseable_raises(): + import click + + with pytest.raises(click.BadParameter): + _parse_size("not-a-size") + + +def test_cli_max_file_size_zero_propagates_to_options(monkeypatch): + captured: dict = {} + + def fake_clone(source, target, options=None): + captured["options"] = options + return CloneReport( + source=Endpoint( + base_url=source.base_url, organization_id=source.organization_id + ), + target=Endpoint( + base_url=target.base_url, organization_id=target.organization_id + ), + ) + + monkeypatch.setattr("unstract.clone.cli.run_clone", fake_clone) + + result = CliRunner().invoke( + cli, + [ + "clone", + "--source-url", + "http://src", + "--source-org", + "src", + "--source-key", + "sk", + "--target-url", + "http://tgt", + "--target-org", + "tgt", + "--target-key", + "tk", + "--max-file-size", + "0", + ], + ) + + assert result.exit_code == 0, result.output + opts: CloneOptions = captured["options"] + assert opts.max_file_size == 0 + + +def test_cli_max_file_size_default_when_flag_omitted(monkeypatch): + captured: dict = {} + + def fake_clone(source, target, options=None): + captured["options"] = options + return CloneReport( + source=Endpoint( + base_url=source.base_url, organization_id=source.organization_id + ), + target=Endpoint( + base_url=target.base_url, organization_id=target.organization_id + ), + ) + + monkeypatch.setattr("unstract.clone.cli.run_clone", fake_clone) + + result = CliRunner().invoke( + cli, + [ + "clone", + "--source-url", + "http://src", + "--source-org", + "src", + "--source-key", + "sk", + "--target-url", + "http://tgt", + "--target-org", + "tgt", + "--target-key", + "tk", + ], + ) + + assert result.exit_code == 0, result.output + opts: CloneOptions = captured["options"] + assert opts.max_file_size == DEFAULT_MAX_FILE_SIZE diff --git 
a/tests/clone/test_client.py b/tests/clone/test_client.py new file mode 100644 index 0000000..9fa3dca --- /dev/null +++ b/tests/clone/test_client.py @@ -0,0 +1,145 @@ +"""Tests for ``PlatformClient`` HTTP layer. + +Coverage: +- URL composition honours base_url, api_path_prefix, organization_id. +- Bearer auth header present on every request. +- Non-2xx response raises ``PlatformAPIError`` with status_code + body. +- 204 / empty body returns ``None`` instead of raising on .json(). +- ``get_post_schema`` parses DRF ``actions.POST`` and caches per path. +- ``close()`` shuts the underlying session; context manager works. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from unstract.clone.client import PlatformClient +from unstract.clone.context import OrgEndpoint +from unstract.clone.exceptions import PlatformAPIError + + +def _endpoint() -> OrgEndpoint: + return OrgEndpoint( + base_url="https://api.example.com", + organization_id="org_abc", + platform_key="plat-key-xyz", + ) + + +def _fake_response(status: int, payload=None, text: str = "") -> MagicMock: + resp = MagicMock() + resp.status_code = status + resp.text = text + resp.content = b"" if payload is None and not text else b"x" + resp.json.return_value = payload + return resp + + +def _client_with_mock( + payload=None, status: int = 200, text: str = "" +) -> tuple[PlatformClient, MagicMock]: + client = PlatformClient(_endpoint()) + mock_request = MagicMock(return_value=_fake_response(status, payload, text)) + client._session.request = mock_request + return client, mock_request + + +def test_url_composition_includes_org_and_api_prefix(): + client, mock_request = _client_with_mock(payload=[]) + client.list_adapters() + call = mock_request.call_args + assert call.args[0] == "GET" + assert call.args[1] == "https://api.example.com/api/v1/unstract/org_abc/adapter/" + + +def test_bearer_token_sent_on_session(): + client, _ = _client_with_mock(payload=[]) + assert 
client._session.headers["Authorization"] == "Bearer plat-key-xyz" + assert client._session.headers["Accept"] == "application/json" + + +def test_non_2xx_raises_platform_api_error_with_status_and_body(): + client, _ = _client_with_mock(status=404, text="not found") + with pytest.raises(PlatformAPIError) as exc_info: + client.list_adapters() + err = exc_info.value + assert err.status_code == 404 + assert "not found" in err.body + + +def test_500_with_long_body_truncated_to_2000_chars(): + big = "x" * 5000 + client, _ = _client_with_mock(status=500, text=big) + with pytest.raises(PlatformAPIError) as exc_info: + client.list_adapters() + assert len(exc_info.value.body) == 2000 + + +def test_204_no_content_returns_none(): + client = PlatformClient(_endpoint()) + resp = MagicMock() + resp.status_code = 204 + resp.content = b"" + client._session.request = MagicMock(return_value=resp) + assert client._request("DELETE", "tag/abc/") is None + + +def test_get_post_schema_parses_options_and_caches(): + options_body = { + "actions": { + "POST": { + "name": {"read_only": False}, + "id": {"read_only": True}, + "shared_to_org": {"read_only": False}, + # No read_only key → treated as writable. + "description": {}, + } + } + } + client, mock_request = _client_with_mock(payload=options_body) + writable = client.get_post_schema("adapter/") + assert writable == frozenset({"name", "shared_to_org", "description"}) + # second call hits cache — no extra HTTP. 
+ writable2 = client.get_post_schema("adapter/") + assert writable2 is writable + assert mock_request.call_count == 1 + + +def test_get_post_schema_handles_missing_actions_block(): + client, _ = _client_with_mock(payload={}) + assert client.get_post_schema("connector/") == frozenset() + + +def test_close_shuts_session(): + client = PlatformClient(_endpoint()) + sess = client._session + sess.close = MagicMock() + client.close() + sess.close.assert_called_once() + + +def test_context_manager_closes_on_exit(): + with PlatformClient(_endpoint()) as client: + client._session.close = MagicMock() + sess_close = client._session.close + sess_close.assert_called_once() + + +def test_list_endpoint_unwraps_paginated_envelope(): + client, _ = _client_with_mock(payload={"results": [{"id": "a"}, {"id": "b"}]}) + items = client.list_tags() + assert [i["id"] for i in items] == ["a", "b"] + + +def test_list_endpoint_accepts_bare_list(): + client, _ = _client_with_mock(payload=[{"id": "a"}]) + items = client.list_tags() + assert items == [{"id": "a"}] + + +def test_options_response_with_null_body_still_yields_empty_schema(): + # Some deployments return 200 with no body on OPTIONS. 
+ client, _ = _client_with_mock(payload=None, text="") + assert client.get_post_schema("pipeline/") == frozenset() diff --git a/tests/clone/test_custom_tool_phase.py b/tests/clone/test_custom_tool_phase.py index ad1dfd2..dac4a13 100644 --- a/tests/clone/test_custom_tool_phase.py +++ b/tests/clone/test_custom_tool_phase.py @@ -25,7 +25,6 @@ from unstract.clone.phases.custom_tool import CustomToolPhase from unstract.clone.report import CloneReport - ADAPTER_NAMES = { "llm": "gpt4", "embedding_model": "ada-embed", @@ -168,7 +167,12 @@ def _src_default_profile(*, nested: bool = False) -> dict: def _src_export_blob(tool_name: str) -> dict: return { - "tool_metadata": {"tool_name": tool_name, "description": "x", "author": "a", "icon": None}, + "tool_metadata": { + "tool_name": tool_name, + "description": "x", + "author": "a", + "icon": None, + }, "tool_settings": {"preamble": "p", "postamble": "q"}, "default_profile_settings": { "chunk_size": 1024, @@ -179,7 +183,11 @@ def _src_export_blob(tool_name: str) -> dict: "profile_name": "Default", }, "prompts": [ - {"prompt_key": "field_a", "prompt": "What is field_a?", "sequence_number": 1} + { + "prompt_key": "field_a", + "prompt": "What is field_a?", + "sequence_number": 1, + } ], "export_metadata": {"exported_at": "2026-05-24T00:00:00Z"}, } @@ -301,6 +309,27 @@ def test_dry_run_makes_no_writes(): assert tgt.export_tool_calls == [] +def test_dry_run_on_adopt_path_does_not_republish_registry(): + # Adopt path used to return tgt_tool_id even on dry-run, falling + # through to export_custom_tool (a real POST to the target). 
+ src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "Invoice Extractor") + tgt.tools["tgt-existing"] = {"tool_name": "Invoice Extractor"} + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt, dry_run=True) + + result = CustomToolPhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert tgt.sync_calls == [] + assert tgt.import_calls == [] + # Critical regression: registry republish must NOT fire on dry-run. + assert tgt.export_tool_calls == [] + # Remap still recorded so downstream dry-run output stays coherent. + assert ctx.remap.resolve("custom_tool", "src-tool-x") == "tgt-existing" + + def test_missing_target_adapter_fails_tool_cleanly(): src = FakeClient() tgt = FakeClient() diff --git a/tests/clone/test_files_phase.py b/tests/clone/test_files_phase.py index 1392ca7..50f739e 100644 --- a/tests/clone/test_files_phase.py +++ b/tests/clone/test_files_phase.py @@ -29,7 +29,6 @@ from unstract.clone.phases.files import FilesPhase from unstract.clone.report import CloneReport - SRC_ENDPOINT = OrgEndpoint( base_url="http://src", organization_id="src-org", platform_key="src-key" ) @@ -53,9 +52,7 @@ def __init__( k: list(v) for k, v in (documents or {}).items() } # (tool_id, file_name) -> {"data": ..., "mime_type": ...} - self._file_payloads: dict[tuple[str, str], dict] = dict( - file_payloads or {} - ) + self._file_payloads: dict[tuple[str, str], dict] = dict(file_payloads or {}) self._tools = list(tools or []) self.uploaded: list[dict[str, Any]] = [] self.list_calls: list[str] = [] @@ -124,8 +121,9 @@ def update_custom_tool(self, tool_id: str, body: dict) -> dict: return {} -def _ctx(src: FakeClient, tgt: FakeClient, *, remap: RemapTable | None = None, - **opts) -> CloneContext: +def _ctx( + src: FakeClient, tgt: FakeClient, *, remap: RemapTable | None = None, **opts +) -> CloneContext: remap = remap or RemapTable() return CloneContext( source=src, @@ -156,7 +154,9 @@ def test_happy_path_uploads_pdf_and_text(): ("src-1", 
"notes.txt"): _text_payload("hello world"), }, ) - tgt = FakeClient(endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}]) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) remap = RemapTable() remap.record("custom_tool", "src-1", "tgt-1") ctx = _ctx(src, tgt, remap=remap) @@ -210,7 +210,9 @@ def test_oversize_file_is_recorded_and_siblings_continue(): ("src-1", "small.txt"): _text_payload("ok"), }, ) - tgt = FakeClient(endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}]) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) remap = RemapTable() remap.record("custom_tool", "src-1", "tgt-1") ctx = _ctx(src, tgt, remap=remap, max_file_size=10) @@ -219,6 +221,10 @@ def test_oversize_file_is_recorded_and_siblings_continue(): result = FilesPhase(ctx).run(report) assert result.created == 1 + # Oversize must bump skipped so the operator sees it surfaced in the + # phase counters, not only in the report's list. + assert result.skipped == 1 + assert result.failed == 0 assert {u["file_name"] for u in tgt.uploaded} == {"small.txt"} assert len(report.oversize_files) == 1 over = report.oversize_files[0] @@ -238,7 +244,9 @@ def test_unsupported_mime_is_recorded_not_uploaded(): } }, ) - tgt = FakeClient(endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}]) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) remap = RemapTable() remap.record("custom_tool", "src-1", "tgt-1") ctx = _ctx(src, tgt, remap=remap) @@ -247,6 +255,10 @@ def test_unsupported_mime_is_recorded_not_uploaded(): result = FilesPhase(ctx).run(report) assert result.created == 0 + # Unsupported mimes must bump skipped so the run doesn't report green + # while leaving files unmoved. 
+ assert result.skipped == 1 + assert result.failed == 0 assert tgt.uploaded == [] assert len(report.unsupported_files) == 1 entry = report.unsupported_files[0] @@ -254,6 +266,29 @@ def test_unsupported_mime_is_recorded_not_uploaded(): assert entry["mime_type"].startswith("application/vnd.openxmlformats") +def test_malformed_source_dm_row_bumps_skipped_with_error(): + # Renamed-field or partial-serializer response: row lacks + # document_name/document_id. Must surface, not silently disappear. + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [{"tool": "src-1"}, _doc("ok.pdf")]}, + file_payloads={("src-1", "ok.pdf"): _pdf_payload(b"BYTES")}, + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 1 # the well-formed sibling still uploads. + assert result.skipped == 1 # the malformed row. 
+ assert any("malformed source DM row" in e for e in result.errors) + + def test_skip_strategy_emits_skipped_files_no_traffic(): src = FakeClient( endpoint=SRC_ENDPOINT, @@ -282,7 +317,9 @@ def test_dry_run_makes_no_writes_even_for_missing_files(): documents={"src-1": [_doc("a.pdf")]}, file_payloads={("src-1", "a.pdf"): _pdf_payload(b"X")}, ) - tgt = FakeClient(endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}]) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) remap = RemapTable() remap.record("custom_tool", "src-1", "tgt-1") ctx = _ctx(src, tgt, remap=remap, dry_run=True) @@ -305,7 +342,9 @@ def test_transient_503_is_retried_then_succeeds(monkeypatch): src.download_errors[("src-1", "a.pdf")] = [ PlatformAPIError("flaky", status_code=503, body="") ] - tgt = FakeClient(endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}]) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) remap = RemapTable() remap.record("custom_tool", "src-1", "tgt-1") ctx = _ctx(src, tgt, remap=remap) @@ -365,7 +404,9 @@ def test_upload_failure_records_failed_files_entry(): documents={"src-1": [_doc("a.pdf")]}, file_payloads={("src-1", "a.pdf"): _pdf_payload(b"X")}, ) - tgt = FakeClient(endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}]) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) tgt.upload_errors[("tgt-1", "a.pdf")] = [ PlatformAPIError("bad", status_code=400, body="bad") ] @@ -397,7 +438,9 @@ def test_text_mimes_round_trip_as_utf8(mime, raw): documents={"src-1": [_doc("data")]}, file_payloads={("src-1", "data"): _text_payload(raw, mime=mime)}, ) - tgt = FakeClient(endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}]) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) remap = RemapTable() 
remap.record("custom_tool", "src-1", "tgt-1") ctx = _ctx(src, tgt, remap=remap) @@ -470,9 +513,7 @@ def test_default_doc_preserves_existing_target_choice(): # Operator already picked a doc on target — re-run must not clobber. tgt = FakeClient( endpoint=TGT_ENDPOINT, - tools=[ - {"tool_id": "tgt-1", "tool_name": "demo", "output": "operator-pick"} - ], + tools=[{"tool_id": "tgt-1", "tool_name": "demo", "output": "operator-pick"}], ) remap = RemapTable() remap.record("custom_tool", "src-1", "tgt-1") diff --git a/tests/clone/test_orchestrator.py b/tests/clone/test_orchestrator.py new file mode 100644 index 0000000..b7a2626 --- /dev/null +++ b/tests/clone/test_orchestrator.py @@ -0,0 +1,159 @@ +"""End-to-end tests for the ``clone()`` orchestrator. + +Coverage: +- Phase ordering matches ``PHASES`` declaration. +- ``include`` / ``exclude`` route phases through ``skipped_phases``. +- ``CloneError`` raised by a phase aborts the run; subsequent phases skipped. +- Both ``PlatformClient`` instances are closed even when a phase aborts. +- ``RemapTable`` snapshot lands on the report. +""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest + +from unstract.clone import orchestrator +from unstract.clone.context import CloneOptions, OrgEndpoint +from unstract.clone.exceptions import CloneError +from unstract.clone.phases.base import Phase +from unstract.clone.report import CloneReport, PhaseResult + + +class _RecordingPhase(Phase): + """Per-test phase factory; records invocation order on a shared list.""" + + invocations: list[str] = [] + name = "" + + def run(self, report: CloneReport) -> PhaseResult: + _RecordingPhase.invocations.append(self.name) + result = report.get_phase(self.name) + result.created += 1 + # Drop a remap entry so we can prove the snapshot lands on the report. 
+ self.ctx.remap.record(self.name, f"src-{self.name}", f"tgt-{self.name}") + return result + + +def _make_phase(phase_name: str) -> type[Phase]: + return type( + f"FakePhase_{phase_name}", + (_RecordingPhase,), + {"name": phase_name}, + ) + + +@pytest.fixture(autouse=True) +def _reset_invocations(): + _RecordingPhase.invocations = [] + yield + _RecordingPhase.invocations = [] + + +@pytest.fixture +def fake_phases(): + """Replace PHASES with a small deterministic set for the test run.""" + fake = [ + ("adapter", _make_phase("adapter")), + ("connector", _make_phase("connector")), + ("workflow", _make_phase("workflow")), + ] + with patch.object(orchestrator, "PHASES", fake): + yield fake + + +def _src() -> OrgEndpoint: + return OrgEndpoint( + base_url="https://src.example.com", + organization_id="src_org", + platform_key="src-key", + ) + + +def _tgt() -> OrgEndpoint: + return OrgEndpoint( + base_url="https://tgt.example.com", + organization_id="tgt_org", + platform_key="tgt-key", + ) + + +def test_phases_run_in_declared_order(fake_phases): + with patch.object(orchestrator.PlatformClient, "close") as mock_close: + report = orchestrator.clone(_src(), _tgt()) + assert _RecordingPhase.invocations == ["adapter", "connector", "workflow"] + assert [p.name for p in report.phases] == ["adapter", "connector", "workflow"] + # Both clients must close (source + target) regardless of outcome. 
+ assert mock_close.call_count == 2 + + +def test_include_filter_only_runs_listed_phases(fake_phases): + opts = CloneOptions(include=("connector",)) + with patch.object(orchestrator.PlatformClient, "close"): + report = orchestrator.clone(_src(), _tgt(), opts) + assert _RecordingPhase.invocations == ["connector"] + assert set(report.skipped_phases) == {"adapter", "workflow"} + + +def test_exclude_filter_skips_listed_phases(fake_phases): + opts = CloneOptions(exclude=("workflow",)) + with patch.object(orchestrator.PlatformClient, "close"): + report = orchestrator.clone(_src(), _tgt(), opts) + assert _RecordingPhase.invocations == ["adapter", "connector"] + assert report.skipped_phases == ["workflow"] + + +def test_clone_error_aborts_and_skips_subsequent_phases(): + class AbortingPhase(Phase): + name = "connector" + + def run(self, report: CloneReport) -> PhaseResult: + raise CloneError("name collision in 'connector'") + + fake = [ + ("adapter", _make_phase("adapter")), + ("connector", AbortingPhase), + ("workflow", _make_phase("workflow")), + ] + with ( + patch.object(orchestrator, "PHASES", fake), + patch.object(orchestrator.PlatformClient, "close") as mock_close, + ): + report = orchestrator.clone(_src(), _tgt()) + + assert _RecordingPhase.invocations == ["adapter"] + assert report.aborted is True + assert "name collision" in report.abort_reason + # Clients still close on abort. 
+ assert mock_close.call_count == 2 + + +def test_unrelated_exception_propagates_but_still_closes_clients(): + class CrashingPhase(Phase): + name = "connector" + + def run(self, report: CloneReport) -> PhaseResult: + raise RuntimeError("boom") + + fake = [ + ("adapter", _make_phase("adapter")), + ("connector", CrashingPhase), + ] + with ( + patch.object(orchestrator, "PHASES", fake), + patch.object(orchestrator.PlatformClient, "close") as mock_close, + ): + with pytest.raises(RuntimeError, match="boom"): + orchestrator.clone(_src(), _tgt()) + assert mock_close.call_count == 2 + + +def test_remap_snapshot_populated_on_report(fake_phases): + with patch.object(orchestrator.PlatformClient, "close"): + report = orchestrator.clone(_src(), _tgt()) + assert report.remap_snapshot == { + "adapter": {"src-adapter": "tgt-adapter"}, + "connector": {"src-connector": "tgt-connector"}, + "workflow": {"src-workflow": "tgt-workflow"}, + } diff --git a/tests/clone/test_pipeline_phase.py b/tests/clone/test_pipeline_phase.py index 982a960..e69c3cb 100644 --- a/tests/clone/test_pipeline_phase.py +++ b/tests/clone/test_pipeline_phase.py @@ -25,7 +25,6 @@ from unstract.clone.phases.pipeline import PipelinePhase from unstract.clone.report import CloneReport - PIPELINE_POST_SCHEMA = frozenset( { "pipeline_name", @@ -62,6 +61,12 @@ def list_pipelines( result = [p for p in result if p.get("pipeline_type") == pipeline_type] return list(result) + def get_pipeline(self, pipeline_id: str) -> dict: + for p in self.pipelines: + if p["id"] == pipeline_id: + return dict(p) + raise KeyError(pipeline_id) + def create_pipeline(self, payload: dict) -> dict: new = dict(payload) new["id"] = f"tgt-pipeline-{self._next:04d}" @@ -111,9 +116,7 @@ def _ctx(source, target, *, remap=None, **opt_overrides): def test_happy_path_creates_pipeline_with_remapped_workflow(): - src = FakeClient( - [_src_pipeline("src-pl-1", "Daily Invoices", "wf-src-1")] - ) + src = FakeClient([_src_pipeline("src-pl-1", "Daily Invoices", 
"wf-src-1")]) tgt = FakeClient() remap = RemapTable() remap.record("workflow", "wf-src-1", "wf-tgt-1") @@ -129,10 +132,52 @@ def test_happy_path_creates_pipeline_with_remapped_workflow(): assert ctx.remap.resolve("pipeline", "src-pl-1") == posted["id"] +def test_create_uses_per_id_get_not_stripped_list_payload(): + # list_pipelines can omit fields the create serializer expects. Phase + # must re-fetch the full record via get_pipeline before POSTing. + full = _src_pipeline("src-pl-1", "Daily Invoices", "wf-src-1") + full["cron_string"] = "0 5 * * *" # only present on detail serializer. + stripped = {k: v for k, v in full.items() if k not in ("cron_string",)} + + class StripListFakeClient(FakeClient): + def list_pipelines(self, *, name=None, pipeline_type=None): + base = ( + [stripped] + if ( + (name is None or stripped["pipeline_name"] == name) + and ( + pipeline_type is None + or stripped["pipeline_type"] == pipeline_type + ) + ) + else [] + ) + return list(base) + + def get_pipeline(self, pipeline_id): + assert pipeline_id == full["id"] + return dict(full) + + src = StripListFakeClient([full]) + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + PipelinePhase(ctx).run(CloneReport()) + + posted = tgt.posts[0] + # cron_string only existed on the detail GET — proves we did NOT + # POST the stripped list-item payload. 
+ assert posted["cron_string"] == "0 5 * * *" + + def test_default_and_app_pipeline_types_are_skipped(): src = FakeClient( [ - _src_pipeline("src-1", "default-legacy", "wf-src-1", pipeline_type="DEFAULT"), + _src_pipeline( + "src-1", "default-legacy", "wf-src-1", pipeline_type="DEFAULT" + ), _src_pipeline("src-2", "streamlit-app", "wf-src-1", pipeline_type="APP"), _src_pipeline("src-3", "real-etl", "wf-src-1", pipeline_type="ETL"), ] @@ -151,9 +196,7 @@ def test_default_and_app_pipeline_types_are_skipped(): def test_adopts_existing_pipeline_by_name(): src = FakeClient([_src_pipeline("src-pl-1", "Daily Invoices", "wf-src-1")]) - tgt = FakeClient( - [{"id": "tgt-existing", "pipeline_name": "Daily Invoices"}] - ) + tgt = FakeClient([{"id": "tgt-existing", "pipeline_name": "Daily Invoices"}]) remap = RemapTable() remap.record("workflow", "wf-src-1", "wf-tgt-1") ctx = _ctx(src, tgt, remap=remap) diff --git a/tests/clone/test_tool_instance_phase.py b/tests/clone/test_tool_instance_phase.py index eb47405..61e59a1 100644 --- a/tests/clone/test_tool_instance_phase.py +++ b/tests/clone/test_tool_instance_phase.py @@ -43,9 +43,7 @@ def create_tool_instance(self, payload: dict) -> dict: self.create_calls.append(new) return new - def update_tool_instance_metadata( - self, instance_id: str, metadata: dict - ) -> dict: + def update_tool_instance_metadata(self, instance_id: str, metadata: dict) -> dict: self.patch_calls.append((instance_id, metadata)) for wf_instances in self.instances.values(): for ti in wf_instances: @@ -91,7 +89,9 @@ def test_happy_path_creates_instance_then_patches_metadata(): src = FakeClient() src.instances[SRC_WF] = [ _src_ti( - "src-ti-1", SRC_WF, SRC_REG, + "src-ti-1", + SRC_WF, + SRC_REG, { "llm": "My OpenAI", "embedding": "MyEmb", @@ -114,12 +114,17 @@ def test_happy_path_creates_instance_then_patches_metadata(): posted = tgt.create_calls[0] assert posted["workflow_id"] == TGT_WF assert posted["tool_id"] == TGT_REG - # PATCH carries the source settings 
but never the source-internal - # identity fields — the target row already has its own. + # PATCH carries source settings but stamps identity fields with + # target values — backend PATCH overwrites the whole metadata dict. assert len(tgt.patch_calls) == 1 patched_id, patched_metadata = tgt.patch_calls[0] assert patched_id == posted["id"] - assert patched_metadata == {"llm": "My OpenAI", "embedding": "MyEmb"} + assert patched_metadata == { + "llm": "My OpenAI", + "embedding": "MyEmb", + "prompt_registry_id": TGT_REG, + "tool_instance_id": posted["id"], + } assert ctx.remap.resolve("tool_instance", "src-ti-1") == posted["id"] @@ -154,8 +159,18 @@ def test_adopt_existing_target_instance_and_repatch_metadata(): assert result.adopted == 1 assert result.created == 0 assert tgt.create_calls == [] - # PATCH still fires for the adopted instance to align metadata. - assert tgt.patch_calls == [("tgt-pre-ti", src_meta)] + # PATCH still fires on adopt and stamps identity fields with target + # values so the runtime can resolve the registry. + assert tgt.patch_calls == [ + ( + "tgt-pre-ti", + { + "llm": "My OpenAI", + "prompt_registry_id": TGT_REG, + "tool_instance_id": "tgt-pre-ti", + }, + ) + ] assert ctx.remap.resolve("tool_instance", "src-ti-1") == "tgt-pre-ti" @@ -182,3 +197,31 @@ def test_dry_run_does_not_create_or_patch(): assert result.skipped == 1 assert tgt.create_calls == [] assert tgt.patch_calls == [] + + +def test_dry_run_on_adopt_path_does_not_repatch_target(): + # Target already has a tool_instance for the target workflow. On a + # dry-run, we must NOT PATCH its metadata — the adopt branch used to + # fall through to the PATCH call. 
+ src = FakeClient() + src.instances[SRC_WF] = [_src_ti("src-ti-1", SRC_WF, SRC_REG, {"llm": "My OpenAI"})] + tgt = FakeClient() + tgt.instances[TGT_WF] = [ + { + "id": "tgt-pre-ti", + "workflow": TGT_WF, + "tool_id": TGT_REG, + "metadata": {"existing": "untouched"}, + "step": 1, + } + ] + ctx = _ctx(src, tgt, remap=_seed_remap(), dry_run=True) + + result = ToolInstancePhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert result.adopted == 0 + assert tgt.create_calls == [] + assert tgt.patch_calls == [] + # Remap still gets recorded so downstream dry-run output is coherent. + assert ctx.remap.resolve("tool_instance", "src-ti-1") == "tgt-pre-ti" diff --git a/tests/clone/test_workflow_endpoint_phase.py b/tests/clone/test_workflow_endpoint_phase.py index 1df2837..811488f 100644 --- a/tests/clone/test_workflow_endpoint_phase.py +++ b/tests/clone/test_workflow_endpoint_phase.py @@ -24,16 +24,12 @@ def __init__(self) -> None: self.endpoints: dict[str, list[dict]] = {} self.patch_calls: list[tuple[str, dict]] = [] - def list_workflow_endpoints( - self, *, workflow_id: str | None = None - ) -> list[dict]: + def list_workflow_endpoints(self, *, workflow_id: str | None = None) -> list[dict]: if workflow_id is None: return [ep for eps in self.endpoints.values() for ep in eps] return list(self.endpoints.get(workflow_id, [])) - def update_workflow_endpoint( - self, endpoint_id: str, payload: dict - ) -> dict: + def update_workflow_endpoint(self, endpoint_id: str, payload: dict) -> dict: self.patch_calls.append((endpoint_id, payload)) for eps in self.endpoints.values(): for ep in eps: @@ -131,6 +127,34 @@ def test_pairs_endpoints_by_type_and_remaps_connector(): assert ctx.remap.resolve("workflow_endpoint", "src-ep-dest") == "tgt-ep-dest" +def test_endpoint_with_null_connection_type_omits_key_in_payload(): + # Source had connection_type=None (rare but legal on the model). + # Must NOT coerce to "" — backend treats blank as a validation + # failure on the enum. 
Omit the key entirely so backend keeps the + # existing target value. + src = FakeClient() + src.endpoints[SRC_WF] = [ + { + "id": "src-ep-source", + "workflow": SRC_WF, + "endpoint_type": "SOURCE", + "connection_type": None, + "configuration": {}, + "connector_instance": None, + } + ] + tgt = FakeClient() + tgt.endpoints[TGT_WF] = [_tgt_endpoint("tgt-ep-source", "SOURCE")] + ctx = _ctx(src, tgt, remap=_seed_remap()) + + result = WorkflowEndpointPhase(ctx).run(CloneReport()) + + assert result.failed == 0 + assert len(tgt.patch_calls) == 1 + _, payload = tgt.patch_calls[0] + assert "connection_type" not in payload + + def test_endpoint_without_source_connector_patches_with_null(): src = FakeClient() src.endpoints[SRC_WF] = [ @@ -183,9 +207,7 @@ def test_unknown_connector_uuid_skips_endpoint_and_flags_error(): def test_missing_target_endpoint_fails_loudly(): src = FakeClient() - src.endpoints[SRC_WF] = [ - _src_endpoint("src-ep-source", "SOURCE", SRC_CONN, {}) - ] + src.endpoints[SRC_WF] = [_src_endpoint("src-ep-source", "SOURCE", SRC_CONN, {})] tgt = FakeClient() tgt.endpoints[TGT_WF] = [] # No endpoints — anomaly. ctx = _ctx(src, tgt, remap=_seed_remap()) @@ -198,9 +220,7 @@ def test_missing_target_endpoint_fails_loudly(): def test_dry_run_makes_no_patches(): src = FakeClient() - src.endpoints[SRC_WF] = [ - _src_endpoint("src-ep-source", "SOURCE", SRC_CONN, {}) - ] + src.endpoints[SRC_WF] = [_src_endpoint("src-ep-source", "SOURCE", SRC_CONN, {})] tgt = FakeClient() tgt.endpoints[TGT_WF] = [_tgt_endpoint("tgt-ep-source", "SOURCE")] ctx = _ctx(src, tgt, remap=_seed_remap(), dry_run=True) From 5e25736da01b9591b79b07375867071c2c2f0575 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Wed, 27 May 2026 09:07:56 +0530 Subject: [PATCH 24/25] perf(clone): within-phase parallelism + per-phase timing Adds --concurrency N (1-32, default 4) to fan out per-phase work across threads. 
Phases mutate shared state (counters, RemapTable, file lists) under a lock owned by the parallel_map helper. concurrency=1 short- circuits the executor for byte-for-byte identical sequential behaviour. Also: - PlatformAPIError now surfaces response body in str(e) so logger.exception emits the backend error text. - CloneReport tracks per-phase + total wall-clock; rendered as a Time column in the run report. - Files phase restructured into 3 passes (prep sequentially, per-file download/upload in parallel, set default doc sequentially). --- src/unstract/clone/cli.py | 12 +- src/unstract/clone/context.py | 3 + src/unstract/clone/exceptions.py | 7 +- src/unstract/clone/orchestrator.py | 13 + src/unstract/clone/phases/adapter.py | 61 ++-- src/unstract/clone/phases/api_deployment.py | 48 ++- src/unstract/clone/phases/base.py | 68 +++- src/unstract/clone/phases/connector.py | 55 ++-- src/unstract/clone/phases/custom_tool.py | 101 +++--- src/unstract/clone/phases/files.py | 205 ++++++------ src/unstract/clone/phases/pipeline.py | 48 ++- src/unstract/clone/phases/tag.py | 33 +- src/unstract/clone/phases/tool_instance.py | 73 +++-- src/unstract/clone/phases/workflow.py | 43 ++- .../clone/phases/workflow_endpoint.py | 76 +++-- src/unstract/clone/report.py | 61 +++- tests/clone/test_phase_concurrency.py | 291 ++++++++++++++++++ 17 files changed, 888 insertions(+), 310 deletions(-) create mode 100644 tests/clone/test_phase_concurrency.py diff --git a/src/unstract/clone/cli.py b/src/unstract/clone/cli.py index 6b7a0b4..d2ed358 100644 --- a/src/unstract/clone/cli.py +++ b/src/unstract/clone/cli.py @@ -16,6 +16,7 @@ import click from unstract.clone.context import ( + DEFAULT_CONCURRENCY, DEFAULT_MAX_FILE_SIZE, CloneOptions, OrgEndpoint, @@ -133,6 +134,13 @@ def cli() -> None: is_flag=True, help="Alias for --file-strategy=skip.", ) +@click.option( + "--concurrency", + type=click.IntRange(min=1, max=32), + default=DEFAULT_CONCURRENCY, + show_default=True, + help="Per-phase worker 
count. 1 = strictly sequential.", +) @click.option("-v", "--verbose", is_flag=True, help="Debug logging") def clone_cmd( source_url: str, @@ -149,6 +157,7 @@ def clone_cmd( file_strategy: str, max_file_size: str, skip_files: bool, + concurrency: int, verbose: bool, ) -> None: """Clone configured resources from one org to another.""" @@ -167,9 +176,8 @@ def clone_cmd( on_name_conflict=on_name_conflict, verbose=verbose, file_strategy=effective_strategy, - # Distinguish "user said 0" (force every file to manual list) from - # an unparseable size — `_parse_size` raises in the latter case. max_file_size=cap_bytes if cap_bytes is not None else DEFAULT_MAX_FILE_SIZE, + concurrency=concurrency, ) source = OrgEndpoint( diff --git a/src/unstract/clone/context.py b/src/unstract/clone/context.py index 5a56be4..e5d6b0f 100644 --- a/src/unstract/clone/context.py +++ b/src/unstract/clone/context.py @@ -35,6 +35,7 @@ class OrgEndpoint: DEFAULT_MAX_FILE_SIZE = 25 * 1024 * 1024 # 25 MB; oversize → manual-upload list +DEFAULT_CONCURRENCY = 4 @dataclass @@ -50,6 +51,8 @@ class CloneOptions: # "skip": metadata only; operator re-uploads via UI on target. file_strategy: str = "platform_api" max_file_size: int = DEFAULT_MAX_FILE_SIZE + # Per-phase worker fan-out. 1 = sequential (no executor). 
+ concurrency: int = DEFAULT_CONCURRENCY def includes(self, phase_name: str) -> bool: if self.include is not None and phase_name not in self.include: diff --git a/src/unstract/clone/exceptions.py b/src/unstract/clone/exceptions.py index 47bffe2..3933c1c 100644 --- a/src/unstract/clone/exceptions.py +++ b/src/unstract/clone/exceptions.py @@ -8,8 +8,11 @@ class CloneError(Exception): class PlatformAPIError(CloneError): """Raised when the Platform API returns a non-2xx response we can't recover from.""" - def __init__(self, message: str, status_code: int | None = None, body: str | None = None): - super().__init__(message) + def __init__( + self, message: str, status_code: int | None = None, body: str | None = None + ): + full_message = f"{message}\n body: {body}" if body else message + super().__init__(full_message) self.status_code = status_code self.body = body diff --git a/src/unstract/clone/orchestrator.py b/src/unstract/clone/orchestrator.py index 552c0b0..a7c81a0 100644 --- a/src/unstract/clone/orchestrator.py +++ b/src/unstract/clone/orchestrator.py @@ -12,6 +12,7 @@ from __future__ import annotations import logging +import time from unstract.clone.client import PlatformClient from unstract.clone.context import CloneContext, CloneOptions, OrgEndpoint @@ -81,20 +82,32 @@ def clone( ), ) + run_started = time.perf_counter() for name, phase_cls in PHASES: if not opts.includes(name): report.skipped_phases.append(name) logger.info("Phase '%s' skipped (excluded)", name) continue logger.info("=== Phase: %s ===", name) + phase_started = time.perf_counter() try: phase_cls(ctx).run(report) except CloneError as e: report.aborted = True report.abort_reason = str(e) logger.error("Phase '%s' aborted: %s", name, e) + # Stamp duration even on abort so the report reflects time spent. 
+ report.get_phase(name).duration_s = time.perf_counter() - phase_started break + else: + report.get_phase(name).duration_s = time.perf_counter() - phase_started + logger.info( + "=== Phase '%s' done in %.2fs ===", + name, + report.get_phase(name).duration_s, + ) + report.total_duration_s = time.perf_counter() - run_started report.remap_snapshot = ctx.remap.snapshot() return report finally: diff --git a/src/unstract/clone/phases/adapter.py b/src/unstract/clone/phases/adapter.py index 9fb5668..522629f 100644 --- a/src/unstract/clone/phases/adapter.py +++ b/src/unstract/clone/phases/adapter.py @@ -12,6 +12,7 @@ from __future__ import annotations import logging +import threading from typing import Any from unstract.clone.exceptions import NameConflictError @@ -44,29 +45,38 @@ def run(self, report: CloneReport) -> PhaseResult: return result logger.info("Found %d adapter(s) in source org", len(src_summaries)) - for summary in src_summaries: - self._clone_one(summary, result) + self.parallel_map( + src_summaries, + lambda summary, lock: self._clone_one(summary, result, lock), + ) return result - def _clone_one(self, summary: dict[str, Any], result: PhaseResult) -> None: + def _clone_one( + self, summary: dict[str, Any], result: PhaseResult, lock: threading.Lock + ) -> None: name = summary["adapter_name"] atype = summary["adapter_type"] src_id = summary["id"] - # List response omits adapter_metadata; fetch detail to pick it up. 
try: src = self.ctx.source.get_adapter(src_id) except Exception as e: - logger.exception("Failed to GET source adapter %s [%s] detail: %s", name, atype, e) - result.failed += 1 - result.errors.append(f"GET source detail {name} [{atype}]: {e}") + logger.exception( + "Failed to GET source adapter %s [%s] detail: %s", name, atype, e + ) + with lock: + result.failed += 1 + result.errors.append(f"GET source detail {name} [{atype}]: {e}") return try: existing = self.ctx.target.list_adapters(name=name, adapter_type=atype) except Exception as e: - logger.exception("Failed to GET adapter %s [%s] on target: %s", name, atype, e) - result.failed += 1 - result.errors.append(f"GET {name} [{atype}]: {e}") + logger.exception( + "Failed to GET adapter %s [%s] on target: %s", name, atype, e + ) + with lock: + result.failed += 1 + result.errors.append(f"GET {name} [{atype}]: {e}") return if existing: @@ -75,14 +85,21 @@ def _clone_one(self, summary: dict[str, Any], result: PhaseResult) -> None: raise NameConflictError( f"adapter '{name}' [{atype}] already exists in target as {tgt['id']}" ) - result.adopted += 1 + with lock: + result.adopted += 1 logger.info( "adopted adapter '%s' [%s] src=%s -> tgt=%s", - name, atype, src_id, tgt["id"], + name, + atype, + src_id, + tgt["id"], ) elif self.ctx.options.dry_run: - result.skipped += 1 - logger.info("[dry-run] would create adapter '%s' [%s] src=%s", name, atype, src_id) + with lock: + result.skipped += 1 + logger.info( + "[dry-run] would create adapter '%s' [%s] src=%s", name, atype, src_id + ) return else: payload = build_post_payload(src, self._writable) @@ -90,13 +107,19 @@ def _clone_one(self, summary: dict[str, Any], result: PhaseResult) -> None: tgt = self.ctx.target.create_adapter(payload) except Exception as e: logger.exception("Failed to create adapter %s [%s]: %s", name, atype, e) - result.failed += 1 - result.errors.append(f"create {name} [{atype}]: {e}") + with lock: + result.failed += 1 + result.errors.append(f"create {name} 
[{atype}]: {e}") return - result.created += 1 + with lock: + result.created += 1 logger.info( "created adapter '%s' [%s] src=%s -> tgt=%s", - name, atype, src_id, tgt["id"], + name, + atype, + src_id, + tgt["id"], ) - self.ctx.remap.record("adapter", src_id, tgt["id"]) + with lock: + self.ctx.remap.record("adapter", src_id, tgt["id"]) diff --git a/src/unstract/clone/phases/api_deployment.py b/src/unstract/clone/phases/api_deployment.py index bbd20ed..df55983 100644 --- a/src/unstract/clone/phases/api_deployment.py +++ b/src/unstract/clone/phases/api_deployment.py @@ -14,6 +14,7 @@ from __future__ import annotations import logging +import threading from typing import Any from unstract.clone.exceptions import NameConflictError @@ -50,11 +51,15 @@ def run(self, report: CloneReport) -> PhaseResult: return result logger.info("Found %d source API deployment(s)", len(src_deployments)) - for src in src_deployments: - self._clone_one(src, result) + self.parallel_map( + src_deployments, + lambda src, lock: self._clone_one(src, result, lock), + ) return result - def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: + def _clone_one( + self, src: dict[str, Any], result: PhaseResult, lock: threading.Lock + ) -> None: api_name = src["api_name"] src_id = src["id"] src_wf_id = src.get("workflow") or src.get("workflow_id") @@ -63,17 +68,20 @@ def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: logger.warning( "source api_deployment '%s' has no workflow FK — skipping", api_name ) - result.skipped += 1 + with lock: + result.skipped += 1 return - tgt_wf_id = self.ctx.remap.resolve("workflow", src_wf_id) + with lock: + tgt_wf_id = self.ctx.remap.resolve("workflow", src_wf_id) if not tgt_wf_id: logger.warning( "no workflow remap for api_deployment '%s' (src workflow %s) — skipping", api_name, src_wf_id, ) - result.skipped += 1 + with lock: + result.skipped += 1 return try: @@ -82,8 +90,9 @@ def _clone_one(self, src: dict[str, Any], result: 
PhaseResult) -> None: logger.exception( "Failed to GET api_deployment %s on target: %s", api_name, e ) - result.failed += 1 - result.errors.append(f"GET {api_name}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"GET {api_name}: {e}") return if existing: @@ -92,7 +101,8 @@ def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: raise NameConflictError( f"api_deployment '{api_name}' already exists in target as {tgt['id']}" ) - result.adopted += 1 + with lock: + result.adopted += 1 logger.info( "adopted api_deployment '%s' src=%s -> tgt=%s", api_name, @@ -100,21 +110,22 @@ def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: tgt["id"], ) elif self.ctx.options.dry_run: - result.skipped += 1 + with lock: + result.skipped += 1 logger.info( "[dry-run] would create api_deployment '%s' src=%s", api_name, src_id ) return else: try: - # list serializer can strip fields the create serializer expects. full_src = self.ctx.source.get_api_deployment(src_id) except Exception as e: logger.exception( "Failed to GET source api_deployment %s: %s", api_name, e ) - result.failed += 1 - result.errors.append(f"GET src api_deployment {api_name}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"GET src api_deployment {api_name}: {e}") return remapped = remap_uuids(full_src, self.ctx.remap) payload = build_post_payload(remapped, self._writable) @@ -123,10 +134,12 @@ def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: tgt = self.ctx.target.create_api_deployment(payload) except Exception as e: logger.exception("Failed to create api_deployment %s: %s", api_name, e) - result.failed += 1 - result.errors.append(f"create {api_name}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"create {api_name}: {e}") return - result.created += 1 + with lock: + result.created += 1 logger.info( "created api_deployment '%s' src=%s -> tgt=%s", api_name, @@ -135,7 +148,8 @@ def _clone_one(self, src: dict[str, 
Any], result: PhaseResult) -> None: ) self._warn_if_extra_source_keys(src_id, api_name) - self.ctx.remap.record("api_deployment", src_id, tgt["id"]) + with lock: + self.ctx.remap.record("api_deployment", src_id, tgt["id"]) def _warn_if_extra_source_keys(self, src_deployment_id: str, name: str) -> None: try: diff --git a/src/unstract/clone/phases/base.py b/src/unstract/clone/phases/base.py index 6208192..c14f1bc 100644 --- a/src/unstract/clone/phases/base.py +++ b/src/unstract/clone/phases/base.py @@ -3,12 +3,18 @@ from __future__ import annotations import logging +import threading from abc import ABC, abstractmethod -from typing import Any +from collections.abc import Callable, Iterable +from concurrent.futures import FIRST_EXCEPTION, Future, ThreadPoolExecutor, wait +from typing import Any, TypeVar from unstract.clone.context import CloneContext +from unstract.clone.exceptions import CloneError from unstract.clone.report import CloneReport, PhaseResult +T = TypeVar("T") + logger = logging.getLogger(__name__) # DRF OPTIONS reports any ModelSerializer FK/M2M as writable, but the @@ -31,9 +37,7 @@ ) -def build_post_payload( - src: dict[str, Any], writable: frozenset[str] -) -> dict[str, Any]: +def build_post_payload(src: dict[str, Any], writable: frozenset[str]) -> dict[str, Any]: """Project ``src`` onto the writable schema, dropping server-managed fields, ``None`` values, and empty strings (which DRF treats as blank and rejects on required fields). @@ -43,11 +47,7 @@ def build_post_payload( # 0 in (None, "") is False, but `0 not in (...)` falsely returns True). # Explicit identity / equality checks preserve falsy-but-meaningful # values like ``BooleanField`` False and numeric defaults. 
- return { - k: src[k] - for k in keys - if k in src and src[k] is not None and src[k] != "" - } + return {k: src[k] for k in keys if k in src and src[k] is not None and src[k] != ""} class Phase(ABC): @@ -62,3 +62,53 @@ def __init__(self, ctx: CloneContext): def run(self, report: CloneReport) -> PhaseResult: """Migrate all entities of this phase's type. Idempotent across runs.""" raise NotImplementedError + + def parallel_map( + self, + items: Iterable[T], + work_fn: Callable[[T, threading.Lock], None], + ) -> None: + """Fan ``work_fn(item, lock)`` across ``ctx.options.concurrency`` + threads. ``work_fn`` must hold ``lock`` while mutating shared + state. ``CloneError`` from any worker cancels the rest and + re-raises. ``concurrency <= 1`` skips the executor entirely. + """ + materialised = list(items) + if not materialised: + return + + concurrency = max(1, self.ctx.options.concurrency) + lock = threading.Lock() + + if concurrency == 1: + for item in materialised: + work_fn(item, lock) + return + + with ThreadPoolExecutor( + max_workers=concurrency, + thread_name_prefix=f"clone-{self.name}", + ) as pool: + futures: list[Future[None]] = [ + pool.submit(work_fn, item, lock) for item in materialised + ] + done, _ = wait(futures, return_when=FIRST_EXCEPTION) + clone_err: CloneError | None = None + other_err: BaseException | None = None + for fut in done: + if fut.cancelled(): + continue + exc = fut.exception() + if exc is None: + continue + if isinstance(exc, CloneError) and clone_err is None: + clone_err = exc + elif other_err is None: + other_err = exc + if clone_err is not None or other_err is not None: + for fut in futures: + fut.cancel() + if clone_err is not None: + raise clone_err + if other_err is not None: + raise other_err diff --git a/src/unstract/clone/phases/connector.py b/src/unstract/clone/phases/connector.py index 816456f..5a1c173 100644 --- a/src/unstract/clone/phases/connector.py +++ b/src/unstract/clone/phases/connector.py @@ -20,6 +20,7 @@ from 
__future__ import annotations import logging +import threading from typing import Any from unstract.clone.exceptions import NameConflictError @@ -52,11 +53,15 @@ def run(self, report: CloneReport) -> PhaseResult: return result logger.info("Found %d connector(s) in source org", len(src_summaries)) - for summary in src_summaries: - self._clone_one(summary, result) + self.parallel_map( + src_summaries, + lambda summary, lock: self._clone_one(summary, result, lock), + ) return result - def _clone_one(self, summary: dict[str, Any], result: PhaseResult) -> None: + def _clone_one( + self, summary: dict[str, Any], result: PhaseResult, lock: threading.Lock + ) -> None: name = summary["connector_name"] src_id = summary["id"] @@ -64,26 +69,29 @@ def _clone_one(self, summary: dict[str, Any], result: PhaseResult) -> None: src = self.ctx.source.get_connector(src_id) except Exception as e: logger.exception("Failed to GET source connector %s detail: %s", name, e) - result.failed += 1 - result.errors.append(f"GET source detail {name}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"GET source detail {name}: {e}") return - # Empty metadata means the backend redacted it (auto-provisioned rows - # like Unstract Cloud Storage). We cannot reconstruct it on target. 
if not src.get("connector_metadata"): logger.info( "skipping connector '%s' (src=%s, catalog=%s) — source returned no metadata", - name, src_id, src.get("connector_id"), + name, + src_id, + src.get("connector_id"), ) - result.skipped += 1 + with lock: + result.skipped += 1 return try: existing = self.ctx.target.list_connectors(name=name) except Exception as e: logger.exception("Failed to GET connector %s on target: %s", name, e) - result.failed += 1 - result.errors.append(f"GET {name}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"GET {name}: {e}") return if existing: @@ -92,13 +100,17 @@ def _clone_one(self, summary: dict[str, Any], result: PhaseResult) -> None: raise NameConflictError( f"connector '{name}' already exists in target as {tgt['id']}" ) - result.adopted += 1 + with lock: + result.adopted += 1 logger.info( "adopted connector '%s' src=%s -> tgt=%s", - name, src_id, tgt["id"], + name, + src_id, + tgt["id"], ) elif self.ctx.options.dry_run: - result.skipped += 1 + with lock: + result.skipped += 1 logger.info("[dry-run] would create connector '%s' src=%s", name, src_id) return else: @@ -107,13 +119,18 @@ def _clone_one(self, summary: dict[str, Any], result: PhaseResult) -> None: tgt = self.ctx.target.create_connector(payload) except Exception as e: logger.exception("Failed to create connector %s: %s", name, e) - result.failed += 1 - result.errors.append(f"create {name}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"create {name}: {e}") return - result.created += 1 + with lock: + result.created += 1 logger.info( "created connector '%s' src=%s -> tgt=%s", - name, src_id, tgt["id"], + name, + src_id, + tgt["id"], ) - self.ctx.remap.record("connector", src_id, tgt["id"]) + with lock: + self.ctx.remap.record("connector", src_id, tgt["id"]) diff --git a/src/unstract/clone/phases/custom_tool.py b/src/unstract/clone/phases/custom_tool.py index f633e74..2c80112 100644 --- a/src/unstract/clone/phases/custom_tool.py +++ 
b/src/unstract/clone/phases/custom_tool.py @@ -28,6 +28,7 @@ from __future__ import annotations import logging +import threading from typing import Any from unstract.clone.exceptions import NameConflictError @@ -70,10 +71,6 @@ def run(self, report: CloneReport) -> PhaseResult: return result logger.info("Found %d custom tool(s) in source org", len(src_tools)) - # Fetch the target list once — name-based adoption lookup is - # done per source tool, but the underlying list is invariant - # across the loop barring our own creates (which we splice into - # ``target_tools`` after each create so re-runs stay idempotent). try: target_tools = self.ctx.target.list_custom_tools() except Exception as e: @@ -81,15 +78,27 @@ def run(self, report: CloneReport) -> PhaseResult: result.failed += 1 result.errors.append(f"list target tools: {e}") return result - for summary in src_tools: - self._clone_one(summary, target_tools, result) + + # Updated under lock when a fresh create lands so duplicate + # same-name source rows adopt instead of recreating. 
+ target_by_name: dict[str, dict[str, Any]] = { + t["tool_name"]: t for t in target_tools + } + + self.parallel_map( + src_tools, + lambda summary, lock: self._clone_one( + summary, target_by_name, result, lock + ), + ) return result def _clone_one( self, summary: dict[str, Any], - target_tools: list[dict[str, Any]], + target_by_name: dict[str, dict[str, Any]], result: PhaseResult, + lock: threading.Lock, ) -> None: tool_name = summary["tool_name"] src_tool_id = summary["tool_id"] @@ -98,30 +107,34 @@ def _clone_one( export_data = self.ctx.source.export_project(src_tool_id) except Exception as e: logger.exception("Failed to export source tool '%s': %s", tool_name, e) - result.failed += 1 - result.errors.append(f"export src tool {tool_name}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"export src tool {tool_name}: {e}") return - match = next((t for t in target_tools if t["tool_name"] == tool_name), None) + with lock: + match = target_by_name.get(tool_name) if match is not None: tgt_tool_id = self._adopt( - match, export_data, result, tool_name, src_tool_id + match, export_data, result, tool_name, src_tool_id, lock ) else: tgt_tool_id = self._create_fresh( - export_data, src_tool_id, tool_name, result + export_data, src_tool_id, tool_name, result, lock ) - # Keep the local cache in sync so a downstream source tool - # with the same name (uncommon but legal) adopts this new - # row instead of trying to re-create it. 
if tgt_tool_id is not None: - target_tools.append({"tool_id": tgt_tool_id, "tool_name": tool_name}) + with lock: + target_by_name[tool_name] = { + "tool_id": tgt_tool_id, + "tool_name": tool_name, + } if tgt_tool_id is None: return - self.ctx.remap.record("custom_tool", src_tool_id, tgt_tool_id) + with lock: + self.ctx.remap.record("custom_tool", src_tool_id, tgt_tool_id) if self.ctx.options.dry_run: return @@ -133,8 +146,9 @@ def _clone_one( ) except Exception as e: logger.exception("Registry republish failed for tool %s: %s", tool_name, e) - result.failed += 1 - result.errors.append(f"export {tool_name}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"export {tool_name}: {e}") return try: @@ -147,16 +161,18 @@ def _clone_one( tool_name, e, ) - result.failed += 1 - result.errors.append(f"registry remap lookup {tool_name}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"registry remap lookup {tool_name}: {e}") return if src_regs and tgt_regs: - self.ctx.remap.record( - "prompt_studio_registry", - src_regs[0]["prompt_registry_id"], - tgt_regs[0]["prompt_registry_id"], - ) + with lock: + self.ctx.remap.record( + "prompt_studio_registry", + src_regs[0]["prompt_registry_id"], + tgt_regs[0]["prompt_registry_id"], + ) def _adopt( self, @@ -165,6 +181,7 @@ def _adopt( result: PhaseResult, tool_name: str, src_tool_id: str, + lock: threading.Lock, ) -> str | None: if self.ctx.options.on_name_conflict == "abort": raise NameConflictError( @@ -173,7 +190,8 @@ def _adopt( tgt_tool_id = match["tool_id"] if self.ctx.options.dry_run: - result.skipped += 1 + with lock: + result.skipped += 1 logger.info( "[dry-run] would sync prompts into adopted tool '%s' src=%s -> tgt=%s", tool_name, @@ -186,11 +204,13 @@ def _adopt( self.ctx.target.sync_prompts(tgt_tool_id, export_data) except Exception as e: logger.exception("sync_prompts failed for tool %s: %s", tool_name, e) - result.failed += 1 - result.errors.append(f"sync {tool_name}: {e}") + with lock: + 
result.failed += 1 + result.errors.append(f"sync {tool_name}: {e}") return None - result.adopted += 1 + with lock: + result.adopted += 1 logger.info( "adopted tool '%s' src=%s -> tgt=%s (prompts re-synced)", tool_name, @@ -205,9 +225,11 @@ def _create_fresh( src_tool_id: str, tool_name: str, result: PhaseResult, + lock: threading.Lock, ) -> str | None: if self.ctx.options.dry_run: - result.skipped += 1 + with lock: + result.skipped += 1 logger.info( "[dry-run] would import tool '%s' src=%s", tool_name, src_tool_id ) @@ -215,22 +237,25 @@ def _create_fresh( adapter_ids = self._resolve_target_adapter_ids(src_tool_id, tool_name) if adapter_ids is None: - result.failed += 1 - result.errors.append( - f"import {tool_name}: missing target adapter remap for default profile" - ) + with lock: + result.failed += 1 + result.errors.append( + f"import {tool_name}: missing target adapter remap for default profile" + ) return None try: tgt = self.ctx.target.import_project(export_data, adapter_ids=adapter_ids) except Exception as e: logger.exception("import_project failed for tool %s: %s", tool_name, e) - result.failed += 1 - result.errors.append(f"import {tool_name}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"import {tool_name}: {e}") return None tgt_tool_id = tgt["tool_id"] - result.created += 1 + with lock: + result.created += 1 logger.info( "created tool '%s' src=%s -> tgt=%s (needs_adapter_config=%s)", tool_name, diff --git a/src/unstract/clone/phases/files.py b/src/unstract/clone/phases/files.py index d28ce6f..8404651 100644 --- a/src/unstract/clone/phases/files.py +++ b/src/unstract/clone/phases/files.py @@ -18,16 +18,16 @@ - No download/upload. Source DM list is emitted into ``skipped_files`` so the operator knows what to re-upload manually via UI. -Concurrency is 1 per phase by design — the Platform API endpoint holds a -cloud worker for the whole upload, and uploads are not chunked on the BE -helper today. 
+Per-file work fans out across ``ctx.options.concurrency`` workers. """ from __future__ import annotations import base64 import logging +import threading import time +from dataclasses import dataclass from typing import Any import requests @@ -38,9 +38,6 @@ logger = logging.getLogger(__name__) -# Mime types the BE's fetch_contents_ide endpoint round-trips losslessly. -# PDF → base64; text/plain + text/csv → utf-8 string. Excel and other -# types return a placeholder/unhandled — must be flagged for manual upload. _BASE64_MIMES: frozenset[str] = frozenset({"application/pdf"}) _TEXT_MIMES: frozenset[str] = frozenset({"text/plain", "text/csv"}) @@ -49,6 +46,15 @@ _RETRY_BACKOFF_BASE_SECONDS = 1.0 +@dataclass +class _FileTask: + src_tool_id: str + tgt_tool_id: str + tool_name: str + file_name: str + src_document_id: str + + class FilesPhase(Phase): name = "files" @@ -61,12 +67,16 @@ def run(self, report: CloneReport) -> PhaseResult: strategy = self.ctx.options.file_strategy logger.info( - "files phase: strategy=%s tools=%d cap=%d bytes", + "files phase: strategy=%s tools=%d cap=%d bytes concurrency=%d", strategy, len(tool_remap), self.ctx.options.max_file_size, + self.ctx.options.concurrency, ) + # Pass 1: build per-file task list sequentially (cheap). + file_tasks: list[_FileTask] = [] + cloned_tools: list[tuple[str, str, str, list[dict[str, Any]]]] = [] for src_tool_id, tgt_tool_id in tool_remap.items(): tool_name = self._lookup_tool_name(tgt_tool_id) or src_tool_id try: @@ -87,13 +97,27 @@ def run(self, report: CloneReport) -> PhaseResult: ) continue - self._clone_tool( + tasks = self._build_tool_tasks( src_tool_id, tgt_tool_id, tool_name, src_docs, report, result ) + file_tasks.extend(tasks) + cloned_tools.append((src_tool_id, tgt_tool_id, tool_name, src_docs)) + + # Pass 2: download + upload each file in parallel. 
+ if file_tasks: + self.parallel_map( + file_tasks, + lambda task, lock: self._clone_one_file(task, report, result, lock), + ) + + # Pass 3: set default doc per tool after all uploads land. + if not self.ctx.options.dry_run and strategy != "skip": + for src_tool_id, tgt_tool_id, tool_name, src_docs in cloned_tools: + self._ensure_default_doc(src_tool_id, tgt_tool_id, tool_name, src_docs) return result - def _clone_tool( + def _build_tool_tasks( self, src_tool_id: str, tgt_tool_id: str, @@ -101,7 +125,7 @@ def _clone_tool( src_docs: list[dict[str, Any]], report: CloneReport, result: PhaseResult, - ) -> None: + ) -> list[_FileTask]: try: tgt_docs = self.ctx.target.list_prompt_documents(tgt_tool_id) except Exception as e: @@ -112,9 +136,10 @@ def _clone_tool( ) result.failed += 1 result.errors.append(f"list target docs {tool_name}: {e}") - return + return [] target_names = {d["document_name"] for d in tgt_docs} + tasks: list[_FileTask] = [] for doc in src_docs: file_name = doc.get("document_name") src_document_id = doc.get("document_id") @@ -145,52 +170,48 @@ def _clone_tool( file_name, ) continue - self._clone_one_file( - src_tool_id, - tgt_tool_id, - tool_name, - file_name, - src_document_id, - report, - result, + tasks.append( + _FileTask( + src_tool_id=src_tool_id, + tgt_tool_id=tgt_tool_id, + tool_name=tool_name, + file_name=file_name, + src_document_id=src_document_id, + ) ) - - if not self.ctx.options.dry_run: - self._ensure_default_doc(src_tool_id, tgt_tool_id, tool_name, src_docs) + return tasks def _clone_one_file( self, - src_tool_id: str, - tgt_tool_id: str, - tool_name: str, - file_name: str, - src_document_id: str, + task: _FileTask, report: CloneReport, result: PhaseResult, + lock: threading.Lock, ) -> None: try: payload = self._with_retry( lambda: self.ctx.source.download_prompt_file( - src_tool_id, src_document_id + task.src_tool_id, task.src_document_id ), - op=f"download {tool_name}/{file_name}", + op=f"download {task.tool_name}/{task.file_name}", ) 
except Exception as e: logger.exception( "files: download failed tool=%s file=%s: %s", - tool_name, - file_name, + task.tool_name, + task.file_name, e, ) - result.failed += 1 - report.failed_files.append( - { - "tool_id": tgt_tool_id, - "tool_name": tool_name, - "file_name": file_name, - "error": f"download: {e}", - } - ) + with lock: + result.failed += 1 + report.failed_files.append( + { + "tool_id": task.tgt_tool_id, + "tool_name": task.tool_name, + "file_name": task.file_name, + "error": f"download: {e}", + } + ) return mime = (payload or {}).get("mime_type") or "" @@ -198,36 +219,38 @@ def _clone_one_file( if raw is None: logger.warning( "files: unsupported mime tool=%s file=%s mime=%s", - tool_name, - file_name, + task.tool_name, + task.file_name, mime, ) - result.skipped += 1 - report.unsupported_files.append( - { - "tool_id": tgt_tool_id, - "tool_name": tool_name, - "file_name": file_name, - "mime_type": mime, - } - ) + with lock: + result.skipped += 1 + report.unsupported_files.append( + { + "tool_id": task.tgt_tool_id, + "tool_name": task.tool_name, + "file_name": task.file_name, + "mime_type": mime, + } + ) return if len(raw) > self.ctx.options.max_file_size: - result.skipped += 1 - report.oversize_files.append( - { - "tool_id": tgt_tool_id, - "tool_name": tool_name, - "file_name": file_name, - "size_bytes": len(raw), - "cap_bytes": self.ctx.options.max_file_size, - } - ) + with lock: + result.skipped += 1 + report.oversize_files.append( + { + "tool_id": task.tgt_tool_id, + "tool_name": task.tool_name, + "file_name": task.file_name, + "size_bytes": len(raw), + "cap_bytes": self.ctx.options.max_file_size, + } + ) logger.info( "files: oversize tool=%s file=%s size=%d cap=%d", - tool_name, - file_name, + task.tool_name, + task.file_name, len(raw), self.ctx.options.max_file_size, ) @@ -236,42 +259,44 @@ def _clone_one_file( try: self._with_retry( lambda: self.ctx.target.upload_prompt_file( - tgt_tool_id, file_name, raw, mime + task.tgt_tool_id, 
task.file_name, raw, mime ), - op=f"upload {tool_name}/{file_name}", + op=f"upload {task.tool_name}/{task.file_name}", ) except Exception as e: logger.exception( "files: upload failed tool=%s file=%s: %s", - tool_name, - file_name, + task.tool_name, + task.file_name, e, ) - result.failed += 1 - report.failed_files.append( + with lock: + result.failed += 1 + report.failed_files.append( + { + "tool_id": task.tgt_tool_id, + "tool_name": task.tool_name, + "file_name": task.file_name, + "error": f"upload: {e}", + } + ) + return + + with lock: + result.created += 1 + report.uploaded_files.append( { - "tool_id": tgt_tool_id, - "tool_name": tool_name, - "file_name": file_name, - "error": f"upload: {e}", + "tool_id": task.tgt_tool_id, + "tool_name": task.tool_name, + "file_name": task.file_name, + "size_bytes": len(raw), + "mime_type": mime, } ) - return - - result.created += 1 - report.uploaded_files.append( - { - "tool_id": tgt_tool_id, - "tool_name": tool_name, - "file_name": file_name, - "size_bytes": len(raw), - "mime_type": mime, - } - ) logger.info( "files: uploaded tool=%s file=%s size=%d", - tool_name, - file_name, + task.tool_name, + task.file_name, len(raw), ) @@ -313,7 +338,6 @@ def _decode_payload( if data_field is None: return None if mime in _BASE64_MIMES: - # data_field is base64-encoded bytes (BE wraps with b64encode). if isinstance(data_field, bytes): return base64.b64decode(data_field) return base64.b64decode(data_field.encode()) @@ -389,9 +413,6 @@ def _pick_default_doc_id( tgt_docs: list[dict[str, Any]], tool_name: str, ) -> str | None: - # Try mirroring the source's selection by filename. If source - # GET fails or source has no chosen doc, fall back to the first - # target doc so the FE doesn't render an empty selector. 
try: src_tool = self.ctx.source.get_custom_tool(src_tool_id) src_output = src_tool.get("output") @@ -428,8 +449,6 @@ def _pick_default_doc_id( return tgt_docs[0].get("document_id") def _lookup_tool_name(self, tgt_tool_id: str) -> str | None: - # Cosmetic helper for logs only — never let a transport hiccup here - # mask a downstream "tool was deleted" or auth failure. try: tools = self.ctx.target.list_custom_tools() except PlatformAPIError as e: diff --git a/src/unstract/clone/phases/pipeline.py b/src/unstract/clone/phases/pipeline.py index 8142128..9892b1c 100644 --- a/src/unstract/clone/phases/pipeline.py +++ b/src/unstract/clone/phases/pipeline.py @@ -14,6 +14,7 @@ from __future__ import annotations import logging +import threading from typing import Any from unstract.clone.exceptions import NameConflictError @@ -61,36 +62,44 @@ def run(self, report: CloneReport) -> PhaseResult: else: logger.info("Found %d source pipeline(s)", len(src_pipelines)) - for src in migratable: - self._clone_one(src, result) + self.parallel_map( + migratable, + lambda src, lock: self._clone_one(src, result, lock), + ) return result - def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: + def _clone_one( + self, src: dict[str, Any], result: PhaseResult, lock: threading.Lock + ) -> None: name = src["pipeline_name"] src_id = src["id"] src_wf_id = src.get("workflow") or src.get("workflow_id") if not src_wf_id: logger.warning("source pipeline '%s' has no workflow FK — skipping", name) - result.skipped += 1 + with lock: + result.skipped += 1 return - tgt_wf_id = self.ctx.remap.resolve("workflow", src_wf_id) + with lock: + tgt_wf_id = self.ctx.remap.resolve("workflow", src_wf_id) if not tgt_wf_id: logger.warning( "no workflow remap for pipeline '%s' (src workflow %s) — skipping", name, src_wf_id, ) - result.skipped += 1 + with lock: + result.skipped += 1 return try: existing = self.ctx.target.list_pipelines(name=name) except Exception as e: logger.exception("Failed to GET 
pipeline %s on target: %s", name, e) - result.failed += 1 - result.errors.append(f"GET {name}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"GET {name}: {e}") return if existing: @@ -99,22 +108,24 @@ def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: raise NameConflictError( f"pipeline '{name}' already exists in target as {tgt['id']}" ) - result.adopted += 1 + with lock: + result.adopted += 1 logger.info( "adopted pipeline '%s' src=%s -> tgt=%s", name, src_id, tgt["id"] ) elif self.ctx.options.dry_run: - result.skipped += 1 + with lock: + result.skipped += 1 logger.info("[dry-run] would create pipeline '%s' src=%s", name, src_id) return else: try: - # list serializer can strip fields the create serializer expects. full_src = self.ctx.source.get_pipeline(src_id) except Exception as e: logger.exception("Failed to GET source pipeline %s: %s", name, e) - result.failed += 1 - result.errors.append(f"GET src pipeline {name}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"GET src pipeline {name}: {e}") return remapped = remap_uuids(full_src, self.ctx.remap) payload = build_post_payload(remapped, self._writable) @@ -123,16 +134,19 @@ def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: tgt = self.ctx.target.create_pipeline(payload) except Exception as e: logger.exception("Failed to create pipeline %s: %s", name, e) - result.failed += 1 - result.errors.append(f"create {name}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"create {name}: {e}") return - result.created += 1 + with lock: + result.created += 1 logger.info( "created pipeline '%s' src=%s -> tgt=%s", name, src_id, tgt["id"] ) self._warn_if_extra_source_keys(src_id, name) - self.ctx.remap.record("pipeline", src_id, tgt["id"]) + with lock: + self.ctx.remap.record("pipeline", src_id, tgt["id"]) def _warn_if_extra_source_keys(self, src_pipeline_id: str, name: str) -> None: try: diff --git a/src/unstract/clone/phases/tag.py 
b/src/unstract/clone/phases/tag.py index e6ae31a..9cbca05 100644 --- a/src/unstract/clone/phases/tag.py +++ b/src/unstract/clone/phases/tag.py @@ -11,6 +11,7 @@ from __future__ import annotations import logging +import threading from typing import Any from unstract.clone.exceptions import NameConflictError @@ -43,11 +44,15 @@ def run(self, report: CloneReport) -> PhaseResult: return result logger.info("Found %d tag(s) in source org", len(src_tags)) - for src in src_tags: - self._clone_one(src, result) + self.parallel_map( + src_tags, + lambda src, lock: self._clone_one(src, result, lock), + ) return result - def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: + def _clone_one( + self, src: dict[str, Any], result: PhaseResult, lock: threading.Lock + ) -> None: name = src["name"] src_id = src["id"] @@ -55,8 +60,9 @@ def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: existing = self.ctx.target.list_tags(name=name) except Exception as e: logger.exception("Failed to GET tag %s on target: %s", name, e) - result.failed += 1 - result.errors.append(f"GET {name}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"GET {name}: {e}") return if existing: @@ -65,10 +71,12 @@ def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: raise NameConflictError( f"tag '{name}' already exists in target as {tgt['id']}" ) - result.adopted += 1 + with lock: + result.adopted += 1 logger.info("adopted tag '%s' src=%s -> tgt=%s", name, src_id, tgt["id"]) elif self.ctx.options.dry_run: - result.skipped += 1 + with lock: + result.skipped += 1 logger.info("[dry-run] would create tag '%s' src=%s", name, src_id) return else: @@ -77,10 +85,13 @@ def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: tgt = self.ctx.target.create_tag(payload) except Exception as e: logger.exception("Failed to create tag %s: %s", name, e) - result.failed += 1 - result.errors.append(f"create {name}: {e}") + with lock: + result.failed 
+= 1 + result.errors.append(f"create {name}: {e}") return - result.created += 1 + with lock: + result.created += 1 logger.info("created tag '%s' src=%s -> tgt=%s", name, src_id, tgt["id"]) - self.ctx.remap.record("tag", src_id, tgt["id"]) + with lock: + self.ctx.remap.record("tag", src_id, tgt["id"]) diff --git a/src/unstract/clone/phases/tool_instance.py b/src/unstract/clone/phases/tool_instance.py index 67e5ca5..71ef5bb 100644 --- a/src/unstract/clone/phases/tool_instance.py +++ b/src/unstract/clone/phases/tool_instance.py @@ -20,6 +20,7 @@ from __future__ import annotations import logging +import threading from typing import Any from unstract.clone.phases.base import Phase @@ -72,12 +73,20 @@ def run(self, report: CloneReport) -> PhaseResult: logger.info("No workflows in remap; nothing to do for tool_instance phase") return result - for src_wf_id, tgt_wf_id in workflow_remap.items(): - self._clone_workflow_tools(src_wf_id, tgt_wf_id, result) + self.parallel_map( + list(workflow_remap.items()), + lambda pair, lock: self._clone_workflow_tools( + pair[0], pair[1], result, lock + ), + ) return result def _clone_workflow_tools( - self, src_wf_id: str, tgt_wf_id: str, result: PhaseResult + self, + src_wf_id: str, + tgt_wf_id: str, + result: PhaseResult, + lock: threading.Lock, ) -> None: try: src_instances = self.ctx.source.list_tool_instances(workflow_id=src_wf_id) @@ -85,14 +94,14 @@ def _clone_workflow_tools( logger.exception( "Failed to list source tool_instances for wf %s: %s", src_wf_id, e ) - result.failed += 1 - result.errors.append(f"list src tool_instances {src_wf_id}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"list src tool_instances {src_wf_id}: {e}") return if not src_instances: return if len(src_instances) > 1: - # Backend enforces ≤1; warn loudly if invariant breaks on source. 
logger.warning( "source workflow %s has %d tool_instances (expected ≤1) — migrating first only", src_wf_id, @@ -103,7 +112,8 @@ def _clone_workflow_tools( src_ti_id = src_ti["id"] src_tool_id = src_ti["tool_id"] - tgt_tool_id = self.ctx.remap.resolve("prompt_studio_registry", src_tool_id) + with lock: + tgt_tool_id = self.ctx.remap.resolve("prompt_studio_registry", src_tool_id) if not tgt_tool_id: logger.warning( "skipping tool_instance %s — no registry remap for tool_id %s " @@ -111,7 +121,8 @@ def _clone_workflow_tools( src_ti_id, src_tool_id, ) - result.skipped += 1 + with lock: + result.skipped += 1 return try: @@ -120,14 +131,17 @@ def _clone_workflow_tools( logger.exception( "Failed to list target tool_instances for wf %s: %s", tgt_wf_id, e ) - result.failed += 1 - result.errors.append(f"list tgt tool_instances {tgt_wf_id}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"list tgt tool_instances {tgt_wf_id}: {e}") return if existing: tgt_ti = existing[0] if self.ctx.options.dry_run: - result.skipped += 1 + with lock: + result.skipped += 1 + self.ctx.remap.record("tool_instance", src_ti_id, tgt_ti["id"]) logger.info( "[dry-run] would re-PATCH metadata on adopted tool_instance " "src=%s -> tgt=%s (workflow %s)", @@ -135,9 +149,9 @@ def _clone_workflow_tools( tgt_ti["id"], tgt_wf_id, ) - self.ctx.remap.record("tool_instance", src_ti_id, tgt_ti["id"]) return - result.adopted += 1 + with lock: + result.adopted += 1 logger.info( "adopted tool_instance src=%s -> tgt=%s (workflow %s)", src_ti_id, @@ -145,7 +159,8 @@ def _clone_workflow_tools( tgt_wf_id, ) elif self.ctx.options.dry_run: - result.skipped += 1 + with lock: + result.skipped += 1 logger.info( "[dry-run] would create tool_instance for tgt workflow %s " "(src tool_instance %s)", @@ -162,10 +177,12 @@ def _clone_workflow_tools( logger.exception( "Failed to create tool_instance for wf %s: %s", tgt_wf_id, e ) - result.failed += 1 - result.errors.append(f"create tool_instance {tgt_wf_id}: {e}") 
+ with lock: + result.failed += 1 + result.errors.append(f"create tool_instance {tgt_wf_id}: {e}") return - result.created += 1 + with lock: + result.created += 1 logger.info( "created tool_instance src=%s -> tgt=%s (workflow %s)", src_ti_id, @@ -173,8 +190,6 @@ def _clone_workflow_tools( tgt_wf_id, ) - # PATCH the metadata regardless of created/adopted — keeps tool config - # aligned with source on every run. src_metadata = src_ti.get("metadata") or {} broken = _broken_adapter_keys(src_metadata) if broken: @@ -186,13 +201,13 @@ def _clone_workflow_tools( tgt_ti["id"], broken, ) - result.errors.append( - f"stale adapter refs on src tool_instance {src_ti_id}: {broken}" - ) + with lock: + result.errors.append( + f"stale adapter refs on src tool_instance {src_ti_id}: {broken}" + ) else: - # PATCH overwrites the whole metadata dict, so we must include - # the target's own identity fields or the runtime sees them - # as empty. tool_id IS the prompt_registry_id. + # PATCH overwrites the whole metadata dict — re-stamp target + # identity fields or the runtime sees them as empty. 
patch_metadata = { **_strip_source_identity(src_metadata), "prompt_registry_id": tgt_tool_id, @@ -206,8 +221,10 @@ def _clone_workflow_tools( logger.exception( "Failed to PATCH tool_instance %s metadata: %s", tgt_ti["id"], e ) - result.failed += 1 - result.errors.append(f"patch metadata {tgt_ti['id']}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"patch metadata {tgt_ti['id']}: {e}") return - self.ctx.remap.record("tool_instance", src_ti_id, tgt_ti["id"]) + with lock: + self.ctx.remap.record("tool_instance", src_ti_id, tgt_ti["id"]) diff --git a/src/unstract/clone/phases/workflow.py b/src/unstract/clone/phases/workflow.py index f8095b9..36d55d4 100644 --- a/src/unstract/clone/phases/workflow.py +++ b/src/unstract/clone/phases/workflow.py @@ -16,6 +16,7 @@ from __future__ import annotations import logging +import threading from typing import Any from unstract.clone.exceptions import NameConflictError @@ -50,11 +51,15 @@ def run(self, report: CloneReport) -> PhaseResult: return result logger.info("Found %d workflow(s) in source org", len(src_workflows)) - for src in src_workflows: - self._clone_one(src, result) + self.parallel_map( + src_workflows, + lambda src, lock: self._clone_one(src, result, lock), + ) return result - def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: + def _clone_one( + self, src: dict[str, Any], result: PhaseResult, lock: threading.Lock + ) -> None: name = src["workflow_name"] src_id = src["id"] @@ -62,8 +67,9 @@ def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: existing = self.ctx.target.list_workflows(name=name) except Exception as e: logger.exception("Failed to GET workflow %s on target: %s", name, e) - result.failed += 1 - result.errors.append(f"GET {name}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"GET {name}: {e}") return if existing: @@ -72,10 +78,14 @@ def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: raise NameConflictError( 
f"workflow '{name}' already exists in target as {tgt['id']}" ) - result.adopted += 1 - logger.info("adopted workflow '%s' src=%s -> tgt=%s", name, src_id, tgt["id"]) + with lock: + result.adopted += 1 + logger.info( + "adopted workflow '%s' src=%s -> tgt=%s", name, src_id, tgt["id"] + ) elif self.ctx.options.dry_run: - result.skipped += 1 + with lock: + result.skipped += 1 logger.info("[dry-run] would create workflow '%s' src=%s", name, src_id) return else: @@ -85,10 +95,15 @@ def _clone_one(self, src: dict[str, Any], result: PhaseResult) -> None: tgt = self.ctx.target.create_workflow(payload) except Exception as e: logger.exception("Failed to create workflow %s: %s", name, e) - result.failed += 1 - result.errors.append(f"create {name}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"create {name}: {e}") return - result.created += 1 - logger.info("created workflow '%s' src=%s -> tgt=%s", name, src_id, tgt["id"]) - - self.ctx.remap.record("workflow", src_id, tgt["id"]) + with lock: + result.created += 1 + logger.info( + "created workflow '%s' src=%s -> tgt=%s", name, src_id, tgt["id"] + ) + + with lock: + self.ctx.remap.record("workflow", src_id, tgt["id"]) diff --git a/src/unstract/clone/phases/workflow_endpoint.py b/src/unstract/clone/phases/workflow_endpoint.py index 8c235d0..a9ffa7a 100644 --- a/src/unstract/clone/phases/workflow_endpoint.py +++ b/src/unstract/clone/phases/workflow_endpoint.py @@ -20,6 +20,7 @@ from __future__ import annotations import logging +import threading from typing import Any from unstract.clone.phases.base import Phase @@ -51,12 +52,20 @@ def run(self, report: CloneReport) -> PhaseResult: ) return result - for src_wf_id, tgt_wf_id in workflow_remap.items(): - self._clone_workflow_endpoints(src_wf_id, tgt_wf_id, result) + self.parallel_map( + list(workflow_remap.items()), + lambda pair, lock: self._clone_workflow_endpoints( + pair[0], pair[1], result, lock + ), + ) return result def _clone_workflow_endpoints( - self, 
src_wf_id: str, tgt_wf_id: str, result: PhaseResult + self, + src_wf_id: str, + tgt_wf_id: str, + result: PhaseResult, + lock: threading.Lock, ) -> None: try: src_endpoints = self.ctx.source.list_workflow_endpoints( @@ -66,8 +75,9 @@ def _clone_workflow_endpoints( logger.exception( "Failed to list source endpoints for wf %s: %s", src_wf_id, e ) - result.failed += 1 - result.errors.append(f"list src endpoints {src_wf_id}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"list src endpoints {src_wf_id}: {e}") return try: @@ -78,8 +88,9 @@ def _clone_workflow_endpoints( logger.exception( "Failed to list target endpoints for wf %s: %s", tgt_wf_id, e ) - result.failed += 1 - result.errors.append(f"list tgt endpoints {tgt_wf_id}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"list tgt endpoints {tgt_wf_id}: {e}") return tgt_by_type = {ep["endpoint_type"]: ep for ep in tgt_endpoints} @@ -88,28 +99,34 @@ def _clone_workflow_endpoints( etype = src_ep["endpoint_type"] tgt_ep = tgt_by_type.get(etype) if tgt_ep is None: - # Target should have auto-created this; missing means the - # workflow create flow failed earlier — surface loudly. 
logger.warning( "target workflow %s missing %s endpoint — skipping", tgt_wf_id, etype, ) - result.failed += 1 - result.errors.append(f"missing tgt {etype} endpoint for wf {tgt_wf_id}") + with lock: + result.failed += 1 + result.errors.append( + f"missing tgt {etype} endpoint for wf {tgt_wf_id}" + ) continue - self._patch_endpoint(src_ep, tgt_ep, result) + self._patch_endpoint(src_ep, tgt_ep, result, lock) def _patch_endpoint( - self, src_ep: dict[str, Any], tgt_ep: dict[str, Any], result: PhaseResult + self, + src_ep: dict[str, Any], + tgt_ep: dict[str, Any], + result: PhaseResult, + lock: threading.Lock, ) -> None: src_ep_id = src_ep["id"] tgt_ep_id = tgt_ep["id"] etype = src_ep["endpoint_type"] if self.ctx.options.dry_run: - result.skipped += 1 + with lock: + result.skipped += 1 logger.info( "[dry-run] would PATCH %s endpoint src=%s -> tgt=%s", etype, @@ -121,12 +138,9 @@ def _patch_endpoint( src_conn_id = _extract_connector_id(src_ep) tgt_conn_id: str | None = None if src_conn_id: - tgt_conn_id = self.ctx.remap.resolve("connector", src_conn_id) + with lock: + tgt_conn_id = self.ctx.remap.resolve("connector", src_conn_id) if not tgt_conn_id: - # Source had a connector but it never made it through the - # connector phase (e.g. redacted secrets, skipped row). - # Patching the endpoint with connector=None would silently - # detach it on target; skip + flag so the operator notices. logger.warning( "skipping %s endpoint src=%s tgt=%s — source connector %s " "has no target remap; would silently unset connector", @@ -135,16 +149,14 @@ def _patch_endpoint( tgt_ep_id, src_conn_id, ) - result.skipped += 1 - result.errors.append( - f"unmapped connector on {etype} endpoint {src_ep_id}: " - f"src_connector={src_conn_id}" - ) + with lock: + result.skipped += 1 + result.errors.append( + f"unmapped connector on {etype} endpoint {src_ep_id}: " + f"src_connector={src_conn_id}" + ) return - # connection_type is a required enum on the backend; pass through - # source's value (incl. 
None) verbatim so the backend's validation - # surfaces the real problem rather than us papering over with "". payload: dict[str, Any] = { "configuration": remap_uuids( src_ep.get("configuration") or {}, self.ctx.remap @@ -161,11 +173,14 @@ def _patch_endpoint( logger.exception( "Failed to PATCH %s endpoint tgt=%s: %s", etype, tgt_ep_id, e ) - result.failed += 1 - result.errors.append(f"patch {etype} {tgt_ep_id}: {e}") + with lock: + result.failed += 1 + result.errors.append(f"patch {etype} {tgt_ep_id}: {e}") return - result.created += 1 + with lock: + result.created += 1 + self.ctx.remap.record("workflow_endpoint", src_ep_id, tgt_ep_id) logger.info( "patched %s endpoint src=%s -> tgt=%s (connector %s)", etype, @@ -173,4 +188,3 @@ def _patch_endpoint( tgt_ep_id, tgt_conn_id, ) - self.ctx.remap.record("workflow_endpoint", src_ep_id, tgt_ep_id) diff --git a/src/unstract/clone/report.py b/src/unstract/clone/report.py index 9b1b9ce..296b8b6 100644 --- a/src/unstract/clone/report.py +++ b/src/unstract/clone/report.py @@ -22,6 +22,7 @@ class PhaseResult: skipped: int = 0 failed: int = 0 errors: list[str] = field(default_factory=list) + duration_s: float = 0.0 @dataclass @@ -41,6 +42,7 @@ class CloneReport: remap_snapshot: dict[str, dict[str, str]] = field(default_factory=dict) aborted: bool = False abort_reason: str | None = None + total_duration_s: float = 0.0 # Files-phase artifacts. Each entry carries enough context for an # operator to act on it without cross-referencing the run log. uploaded_files: list[dict[str, Any]] = field(default_factory=list) @@ -70,11 +72,13 @@ def render(self) -> str: buf = StringIO() # force_terminal so ANSI codes survive the StringIO capture; the # caller decides whether to strip them when printing to a non-tty. 
- console = Console(file=buf, force_terminal=True, color_system="truecolor", width=100) + console = Console( + file=buf, force_terminal=True, color_system="truecolor", width=100 + ) self._render_endpoints(console.print) table = Table(title="Clone Report", header_style="bold cyan") table.add_column("Phase", style="bold", justify="left") - for col in ("Created", "Adopted", "Skipped", "Failed"): + for col in ("Created", "Adopted", "Skipped", "Failed", "Time"): table.add_column(col, justify="right") totals = {"created": 0, "adopted": 0, "skipped": 0, "failed": 0} @@ -86,6 +90,7 @@ def render(self) -> str: self._fmt_count(p.adopted, "green"), self._fmt_count(p.skipped, "yellow"), self._fmt_count(p.failed, "red"), + self._fmt_duration(p.duration_s), ) for k in totals: totals[k] += getattr(p, k) @@ -97,10 +102,13 @@ def render(self) -> str: self._fmt_count(totals["adopted"], "green", bold=True), self._fmt_count(totals["skipped"], "yellow", bold=True), self._fmt_count(totals["failed"], "red", bold=True), + self._fmt_duration(self.total_duration_s, bold=True), ) console.print(table) if self.skipped_phases: - console.print(f"[dim]Skipped phases:[/dim] {', '.join(self.skipped_phases)}") + console.print( + f"[dim]Skipped phases:[/dim] {', '.join(self.skipped_phases)}" + ) self._render_files_sections(console) self._render_remap_summary(console_print=console.print) if self.aborted: @@ -122,15 +130,42 @@ def _fmt_count(value: int, color: str, bold: bool = False) -> str: style = f"bold {color}" if bold else color return f"[{style}]{value}[/{style}]" + @staticmethod + def _fmt_duration(seconds: float, bold: bool = False) -> str: + if seconds <= 0: + return "[dim]—[/dim]" + if seconds < 60: + text = f"{seconds:.1f}s" + else: + mins, secs = divmod(seconds, 60) + text = f"{int(mins)}m{secs:.0f}s" + return f"[bold]{text}[/bold]" if bold else text + + @staticmethod + def _fmt_duration_plain(seconds: float) -> str: + if seconds <= 0: + return "—" + if seconds < 60: + return 
f"{seconds:.1f}s" + mins, secs = divmod(seconds, 60) + return f"{int(mins)}m{secs:.0f}s" + def _render_plain(self) -> str: lines = ["Clone Report", "=" * 60] self._render_endpoints(lines.append) - header = f"{'Phase':<24}{'Created':>10}{'Adopted':>10}{'Skipped':>10}{'Failed':>10}" + header = ( + f"{'Phase':<24}{'Created':>10}{'Adopted':>10}" + f"{'Skipped':>10}{'Failed':>10}{'Time':>10}" + ) lines.append(header) for p in self.phases: lines.append( - f"{p.name:<24}{p.created:>10}{p.adopted:>10}{p.skipped:>10}{p.failed:>10}" + f"{p.name:<24}{p.created:>10}{p.adopted:>10}" + f"{p.skipped:>10}{p.failed:>10}{self._fmt_duration_plain(p.duration_s):>10}" ) + lines.append( + f"{'TOTAL':<64}{self._fmt_duration_plain(self.total_duration_s):>10}" + ) if self.skipped_phases: lines.append(f"Skipped phases: {', '.join(self.skipped_phases)}") lines.extend(self._files_sections_plain()) @@ -142,12 +177,18 @@ def _render_plain(self) -> str: def as_dict(self) -> dict[str, Any]: return { "source": ( - {"base_url": self.source.base_url, "organization_id": self.source.organization_id} + { + "base_url": self.source.base_url, + "organization_id": self.source.organization_id, + } if self.source else None ), "target": ( - {"base_url": self.target.base_url, "organization_id": self.target.organization_id} + { + "base_url": self.target.base_url, + "organization_id": self.target.organization_id, + } if self.target else None ), @@ -159,6 +200,7 @@ def as_dict(self) -> dict[str, Any]: "skipped": p.skipped, "failed": p.failed, "errors": list(p.errors), + "duration_s": p.duration_s, } for p in self.phases ], @@ -166,6 +208,7 @@ def as_dict(self) -> dict[str, Any]: "remap_snapshot": self.remap_snapshot, "aborted": self.aborted, "abort_reason": self.abort_reason, + "total_duration_s": self.total_duration_s, "uploaded_files": list(self.uploaded_files), "skipped_files": list(self.skipped_files), "oversize_files": list(self.oversize_files), @@ -208,9 +251,7 @@ def _render_remap_summary(self, 
console_print: Any) -> None: def _render_files_sections(self, console: Any) -> None: if self.uploaded_files: - console.print( - f"[green]Files uploaded:[/green] {len(self.uploaded_files)}" - ) + console.print(f"[green]Files uploaded:[/green] {len(self.uploaded_files)}") for header, rows in ( ("Oversize files (manual upload required)", self.oversize_files), ("Unsupported mime files (manual upload required)", self.unsupported_files), diff --git a/tests/clone/test_phase_concurrency.py b/tests/clone/test_phase_concurrency.py new file mode 100644 index 0000000..0179f1b --- /dev/null +++ b/tests/clone/test_phase_concurrency.py @@ -0,0 +1,291 @@ +"""Thread-safety checks for ``Phase.parallel_map``. + +Coverage: +- Many-item fan-out produces exact counts + remap entries with no loss. +- Sequential path (``concurrency=1``) skips the thread pool entirely + while preserving identical behaviour. +- ``CloneError`` raised inside a worker propagates out of ``parallel_map`` + so the orchestrator's abort handling engages. +- A non-``CloneError`` exception inside a worker is counted as failed, not raised. + +We use a fake client that holds a lock around its own mutable state and +injects a small sleep per HTTP call to force real interleaving between +workers, then assert the phase's lock-guarded code keeps counters and +the remap table consistent. +""" + +from __future__ import annotations + +import threading +import time + +import pytest + +from unstract.clone.context import CloneContext, CloneOptions, RemapTable +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.adapter import AdapterPhase +from unstract.clone.phases.tag import TagPhase +from unstract.clone.report import CloneReport + + +class _ThreadSafeAdapterClient: + """Adapter FakeClient with a lock around mutable state + per-call sleep + so workers actually interleave under ThreadPoolExecutor. 
+ """ + + POST_SCHEMA = frozenset( + { + "adapter_id", + "adapter_name", + "adapter_type", + "adapter_metadata", + "description", + } + ) + + def __init__(self, adapters=None, sleep_seconds: float = 0.005): + self._adapters: list[dict] = list(adapters or []) + self.posts: list[dict] = [] + self._next_id = 1 + self._lock = threading.Lock() + self._sleep = sleep_seconds + + def get_post_schema(self, entity_path): + return self.POST_SCHEMA + + def list_adapters(self, *, name=None, adapter_type=None): + time.sleep(self._sleep) + with self._lock: + snap = list(self._adapters) + result = snap + if name is not None: + result = [a for a in result if a["adapter_name"] == name] + if adapter_type is not None: + result = [a for a in result if a["adapter_type"] == adapter_type] + return [{k: v for k, v in a.items() if k != "adapter_metadata"} for a in result] + + def get_adapter(self, adapter_pk): + time.sleep(self._sleep) + with self._lock: + for a in self._adapters: + if a["id"] == adapter_pk: + return dict(a) + raise KeyError(adapter_pk) + + def create_adapter(self, payload): + time.sleep(self._sleep) + with self._lock: + new = dict(payload) + new["id"] = f"tgt-{self._next_id:08d}-0000-0000-0000-000000000000" + self._next_id += 1 + self._adapters.append(new) + self.posts.append(new) + return new + + +def _src_adapter(id_, name, atype="LLM"): + return { + "id": id_, + "adapter_id": "openai-llm-v2", + "adapter_name": name, + "adapter_type": atype, + "adapter_metadata": {"api_key": "sk-secret", "model": "gpt-4"}, + "description": f"{name} desc", + } + + +def _ctx(source, target, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=RemapTable(), + ) + + +def test_parallel_map_preserves_counts_with_many_items(): + items = 50 + src = _ThreadSafeAdapterClient( + [_src_adapter(f"src-{i:03d}", f"adapter-{i:03d}") for i in range(items)] + ) + tgt = _ThreadSafeAdapterClient() + ctx = _ctx(src, tgt, concurrency=8) + 
report = CloneReport() + + result = AdapterPhase(ctx).run(report) + + assert result.created == items + assert result.adopted == 0 + assert result.skipped == 0 + assert result.failed == 0 + assert len(tgt.posts) == items + remap = ctx.remap.snapshot().get("adapter", {}) + assert len(remap) == items + # Every source id should be mapped to a fresh target id. + assert set(remap.keys()) == {f"src-{i:03d}" for i in range(items)} + assert len(set(remap.values())) == items + + +def test_concurrency_one_runs_sequentially_with_no_executor(monkeypatch): + """With concurrency=1 we should never hit ThreadPoolExecutor.""" + sentinel = {"executor_used": False} + + import unstract.clone.phases.base as base_mod + + original = base_mod.ThreadPoolExecutor + + class _Forbidden: + def __init__(self, *a, **kw): + sentinel["executor_used"] = True + raise AssertionError("ThreadPoolExecutor must not be used at concurrency=1") + + monkeypatch.setattr(base_mod, "ThreadPoolExecutor", _Forbidden) + src = _ThreadSafeAdapterClient( + [_src_adapter(f"src-{i}", f"a-{i}") for i in range(5)] + ) + tgt = _ThreadSafeAdapterClient() + ctx = _ctx(src, tgt, concurrency=1) + report = CloneReport() + + result = AdapterPhase(ctx).run(report) + + assert result.created == 5 + assert sentinel["executor_used"] is False + # restore for any other tests in same module (monkeypatch undoes on teardown). 
+ base_mod.ThreadPoolExecutor = original # noqa: F841 + + +class _AbortingAdapterClient(_ThreadSafeAdapterClient): + """As parent, but ``list_adapters`` claims the named adapter already + exists on target — used to trigger NameConflictError when the phase + is run with ``on_name_conflict='abort'``.""" + + def list_adapters(self, *, name=None, adapter_type=None): + time.sleep(self._sleep) + return [ + { + "id": "tgt-existing-0001", + "adapter_name": name or "x", + "adapter_type": adapter_type or "LLM", + } + ] + + +def test_clone_error_in_worker_propagates_under_concurrency(): + src = _ThreadSafeAdapterClient( + [_src_adapter(f"src-{i}", f"clash-{i}") for i in range(10)] + ) + tgt = _AbortingAdapterClient() + ctx = _ctx(src, tgt, concurrency=4, on_name_conflict="abort") + report = CloneReport() + + with pytest.raises(NameConflictError): + AdapterPhase(ctx).run(report) + + +class _UnexpectedAdapterClient(_ThreadSafeAdapterClient): + """One of the GETs blows up with a non-Clone RuntimeError.""" + + def __init__(self, *a, fail_on_name: str, **kw): + super().__init__(*a, **kw) + self._fail_on_name = fail_on_name + + def get_adapter(self, adapter_pk): + snap = super().get_adapter(adapter_pk) + if snap["adapter_name"] == self._fail_on_name: + raise RuntimeError("transport boom") + return snap + + +def test_non_clone_exception_recorded_as_failed_not_raised(): + """Workers convert non-Clone errors into ``result.failed`` counts; + they don't escape the phase. (CloneError is the abort signal — + arbitrary exceptions are per-item failures.)""" + src = _UnexpectedAdapterClient( + adapters=[_src_adapter(f"src-{i}", f"adapter-{i}") for i in range(10)], + fail_on_name="adapter-3", + ) + tgt = _ThreadSafeAdapterClient() + ctx = _ctx(src, tgt, concurrency=4) + report = CloneReport() + + result = AdapterPhase(ctx).run(report) + + assert result.failed == 1 + # The other 9 still created successfully. 
+ assert result.created == 9 + assert len(tgt.posts) == 9 + + +class _TagClient: + """Minimal tag fake with thread-safe state + per-call sleep.""" + + POST_SCHEMA = frozenset({"name", "description"}) + + def __init__(self, tags=None, sleep_seconds: float = 0.005): + self._tags: list[dict] = list(tags or []) + self.posts: list[dict] = [] + self._next_id = 1 + self._lock = threading.Lock() + self._sleep = sleep_seconds + + def get_post_schema(self, entity_path): + return self.POST_SCHEMA + + def list_tags(self, *, name=None): + time.sleep(self._sleep) + with self._lock: + snap = list(self._tags) + if name is not None: + snap = [t for t in snap if t["name"] == name] + return snap + + def create_tag(self, payload): + time.sleep(self._sleep) + with self._lock: + new = dict(payload) + new["id"] = f"tag-tgt-{self._next_id:04d}" + self._next_id += 1 + self._tags.append(new) + self.posts.append(new) + return new + + +def test_tag_phase_parallel_remap_table_consistent(): + """Distinct phase exercising the same parallel_map path — ensures the + helper isn't accidentally adapter-specific. + """ + src = _TagClient( + [{"id": f"tag-src-{i}", "name": f"tag-{i:03d}"} for i in range(30)] + ) + tgt = _TagClient() + ctx = _ctx(src, tgt, concurrency=8) + report = CloneReport() + + result = TagPhase(ctx).run(report) + + assert result.created == 30 + assert result.failed == 0 + remap = ctx.remap.snapshot().get("tag", {}) + assert len(remap) == 30 + # remap value uniqueness — no two source tags mapped to the same target id. 
+ assert len(set(remap.values())) == 30 + + +def test_parallel_map_empty_input_no_executor(monkeypatch): + """No items → no thread pool, no work.""" + import unstract.clone.phases.base as base_mod + + class _Forbidden: + def __init__(self, *a, **kw): + raise AssertionError("Should not create pool for empty input") + + monkeypatch.setattr(base_mod, "ThreadPoolExecutor", _Forbidden) + src = _ThreadSafeAdapterClient([]) + tgt = _ThreadSafeAdapterClient() + ctx = _ctx(src, tgt, concurrency=8) + report = CloneReport() + + result = AdapterPhase(ctx).run(report) + assert result.created == 0 + assert result.adopted == 0 From 149bcec2a8878e9ddce9e572d2bfe18b4ec1fa59 Mon Sep 17 00:00:00 2001 From: Chandrasekharan M Date: Wed, 27 May 2026 09:18:10 +0530 Subject: [PATCH 25/25] fix(clone): bump result.skipped when tool_instance metadata PATCH is skipped The broken-adapter-refs branch logged a warning and appended an error entry but didn't bump any counter, so a degraded clone (tool_instance row landed with backend defaults, adapters silently unbound) reported as 'Completed successfully' with exit 0. 
--- src/unstract/clone/phases/tool_instance.py | 1 + tests/clone/test_tool_instance_phase.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/unstract/clone/phases/tool_instance.py b/src/unstract/clone/phases/tool_instance.py index 71ef5bb..293d206 100644 --- a/src/unstract/clone/phases/tool_instance.py +++ b/src/unstract/clone/phases/tool_instance.py @@ -202,6 +202,7 @@ def _clone_workflow_tools( broken, ) with lock: + result.skipped += 1 result.errors.append( f"stale adapter refs on src tool_instance {src_ti_id}: {broken}" ) diff --git a/tests/clone/test_tool_instance_phase.py b/tests/clone/test_tool_instance_phase.py index 61e59a1..180c285 100644 --- a/tests/clone/test_tool_instance_phase.py +++ b/tests/clone/test_tool_instance_phase.py @@ -186,6 +186,27 @@ def test_no_op_when_no_workflows_in_remap(): assert tgt.create_calls == [] +def test_broken_adapter_refs_bumps_skipped_and_records_error(): + src = FakeClient() + src.instances[SRC_WF] = [ + _src_ti( + "src-ti-1", + SRC_WF, + SRC_REG, + {"llm": "[DELETED ADAPTER] My OpenAI", "embedding": "MyEmb"}, + ) + ] + tgt = FakeClient() + ctx = _ctx(src, tgt, remap=_seed_remap()) + + result = ToolInstancePhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert result.skipped == 1 + assert tgt.patch_calls == [] + assert any("stale adapter refs" in e for e in result.errors) + + def test_dry_run_does_not_create_or_patch(): src = FakeClient() src.instances[SRC_WF] = [_src_ti("src-ti-1", SRC_WF, SRC_REG, {"x": 1})]