diff --git a/pyproject.toml b/pyproject.toml index c45a8a9..a208e3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,15 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Python Modules", ] +[project.optional-dependencies] +clone = [ + "click>=8.1", + "rich>=13.7", +] + +[project.scripts] +unstract-clone = "unstract.clone.cli:main" + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" diff --git a/src/unstract/clone/README.md b/src/unstract/clone/README.md new file mode 100644 index 0000000..c155641 --- /dev/null +++ b/src/unstract/clone/README.md @@ -0,0 +1,55 @@ +# Cloning Organizations + +> [!NOTE] +> **Users are not cloned.** Two reasons: +> - The same user may not need access in every environment. +> - The same user may hold different roles across environments. +> +> **Groups _will_ be cloned** (upcoming — not yet implemented). Once available, an admin can add the right users to each group per environment. + +Clone an Unstract organization's configured resources into another organization (same deployment or different). Useful for environment promotion (DEV → QA → PROD) and for spinning up a fresh org from a known-good baseline. + +Cloned resources: adapters, connectors, custom tools, prompts, profiles, workflows, tool instances, workflow endpoints, tags, API deployments, pipelines, and Prompt Studio document files. The source org is left untouched. + +> **Full documentation, behavior notes, CLI reference, and sample report:** +> https://docs.unstract.com/unstract/unstract_platform/api_documentation/versions/cloning-orgs/ + +## Install + +From a clone of this repository: + +```bash +uv sync --all-extras +``` + +This pulls in the `clone` extra (`click`, `rich`) needed by the CLI. + +## Quickstart + +```bash +UNSTRACT_SRC_PLATFORM_KEY=src_pk_... \ +UNSTRACT_TGT_PLATFORM_KEY=tgt_pk_... 
\ +uv run python -m unstract.clone clone \ + --source-url https://source.example.com \ + --source-org my-source-org \ + --target-url https://target.example.com \ + --target-org my-target-org +``` + +Both keys must be **org admin Platform API keys**. + +> [!WARNING] +> Both keys grant broad access. Run from a trusted machine and rotate both keys after the clone completes. + +## Re-runs are safe + +If a phase fails partway, fix the cause and re-run the same command. Resources already on the target are detected by name and reused. There is no `--resume-from` flag — the target is the state. + +## Files + +The Prompt Studio document corpus is the only resource type with bytes on disk. Default cap per file is 25 MB; oversize files are reported for manual re-upload. Use `--skip-files` to skip bytes entirely (document records are still created). + +> [!WARNING] +> Run clones during low-activity windows. Concurrent uploads to the source org during a clone can create duplicate file records on the target. + +See the [public docs](https://docs.unstract.com/unstract/unstract_platform/api_documentation/versions/cloning-orgs/) for the full flag list, behavioral notes, and the format of the end-of-run report. diff --git a/src/unstract/clone/__init__.py b/src/unstract/clone/__init__.py new file mode 100644 index 0000000..c36300b --- /dev/null +++ b/src/unstract/clone/__init__.py @@ -0,0 +1,25 @@ +"""Cloning organizations over the Platform API. + +Migrates configured resources (adapters, connectors, custom tools, workflows, +etc.) from one Unstract org to another using two admin-issued Platform API +keys. The target deployment is the persistent state — re-runs reconcile +against existing target rows by natural key. 
+""" + +from unstract.clone.context import ( + CloneContext, + CloneOptions, + OrgEndpoint, + RemapTable, +) +from unstract.clone.orchestrator import clone +from unstract.clone.report import CloneReport + +__all__ = [ + "CloneContext", + "CloneOptions", + "CloneReport", + "OrgEndpoint", + "RemapTable", + "clone", +] diff --git a/src/unstract/clone/__main__.py b/src/unstract/clone/__main__.py new file mode 100644 index 0000000..2eef45c --- /dev/null +++ b/src/unstract/clone/__main__.py @@ -0,0 +1,6 @@ +"""Entry point: ``python -m unstract.clone``.""" + +from unstract.clone.cli import main + +if __name__ == "__main__": + main() diff --git a/src/unstract/clone/cli.py b/src/unstract/clone/cli.py new file mode 100644 index 0000000..d2ed358 --- /dev/null +++ b/src/unstract/clone/cli.py @@ -0,0 +1,212 @@ +"""Click-based CLI for ``unstract.clone``. + +Single ``clone`` command. Platform keys can be passed via flags +(``--source-key`` / ``--target-key``) or env vars +(``UNSTRACT_SRC_PLATFORM_KEY`` / ``UNSTRACT_TGT_PLATFORM_KEY``) — env vars +are preferred so the key never lands in shell history. +""" + +from __future__ import annotations + +import logging +import re +import sys +from typing import Any + +import click + +from unstract.clone.context import ( + DEFAULT_CONCURRENCY, + DEFAULT_MAX_FILE_SIZE, + CloneOptions, + OrgEndpoint, +) +from unstract.clone.exceptions import CloneError +from unstract.clone.orchestrator import clone as run_clone + +_SIZE_UNITS: dict[str, int] = { + "B": 1, + "K": 1024, + "KB": 1024, + "M": 1024 * 1024, + "MB": 1024 * 1024, + "G": 1024 * 1024 * 1024, + "GB": 1024 * 1024 * 1024, +} +_SIZE_RE = re.compile(r"^\s*(\d+(?:\.\d+)?)\s*([A-Za-z]*)\s*$") + + +def _parse_size(value: str) -> int: + """Accept ``25``, ``25MB``, ``1.5GB`` etc. 
Returns bytes.""" + m = _SIZE_RE.match(value) + if not m: + raise click.BadParameter(f"can't parse size '{value}'") + num, unit = m.group(1), m.group(2).upper() or "B" + if unit not in _SIZE_UNITS: + raise click.BadParameter( + f"unknown size unit '{unit}'; use one of {sorted(_SIZE_UNITS)}" + ) + return int(float(num) * _SIZE_UNITS[unit]) + + +def _configure_logging(verbose: bool) -> None: + level = logging.DEBUG if verbose else logging.INFO + logging.basicConfig( + level=level, + format="%(asctime)s %(levelname)-7s %(name)s: %(message)s", + datefmt="%H:%M:%S", + ) + + +def _split_csv(value: str | None) -> tuple[str, ...] | None: + if not value: + return None + return tuple(p.strip() for p in value.split(",") if p.strip()) + + +@click.group() +def cli() -> None: + """Cloning organizations over the Platform API.""" + + +@cli.command("clone") +@click.option("--source-url", required=True, help="Base URL of the source deployment") +@click.option( + "--source-org", required=True, help="Source organization_id (slug in the URL path)" +) +@click.option( + "--source-key", + envvar="UNSTRACT_SRC_PLATFORM_KEY", + required=True, + help="Source admin's Platform API key (or env UNSTRACT_SRC_PLATFORM_KEY)", +) +@click.option("--target-url", required=True, help="Base URL of the target deployment") +@click.option( + "--target-org", required=True, help="Target organization_id (slug in the URL path)" +) +@click.option( + "--target-key", + envvar="UNSTRACT_TGT_PLATFORM_KEY", + required=True, + help="Target admin's Platform API key (or env UNSTRACT_TGT_PLATFORM_KEY)", +) +@click.option( + "--dry-run", is_flag=True, help="Plan only — do not POST anything to target" +) +@click.option( + "--include", + default=None, + help="Comma-separated phase names to include (default: all)", +) +@click.option( + "--exclude", + default=None, + help="Comma-separated phase names to exclude", +) +@click.option( + "--on-name-conflict", + type=click.Choice(["adopt", "abort"]), + default="adopt", + 
@click.group()
def cli() -> None:
    """Cloning organizations over the Platform API."""


def _reject_unknown_phases(flag: str, names: tuple[str, ...] | None) -> None:
    """Raise ``click.UsageError`` if ``names`` contains a non-existent phase.

    Without this check a typo such as ``--include adaptor`` matches no phase,
    so every phase is skipped and the run still reports success.
    """
    # Local import: only this check needs the orchestrator's phase table.
    from unstract.clone.orchestrator import PHASES

    known = [name for name, _ in PHASES]
    unknown = [name for name in names or () if name not in known]
    if unknown:
        raise click.UsageError(
            f"unknown phase name(s) for {flag}: {', '.join(unknown)}; "
            f"valid phases: {', '.join(known)}"
        )


@cli.command("clone")
@click.option("--source-url", required=True, help="Base URL of the source deployment")
@click.option(
    "--source-org", required=True, help="Source organization_id (slug in the URL path)"
)
@click.option(
    "--source-key",
    envvar="UNSTRACT_SRC_PLATFORM_KEY",
    required=True,
    help="Source admin's Platform API key (or env UNSTRACT_SRC_PLATFORM_KEY)",
)
@click.option("--target-url", required=True, help="Base URL of the target deployment")
@click.option(
    "--target-org", required=True, help="Target organization_id (slug in the URL path)"
)
@click.option(
    "--target-key",
    envvar="UNSTRACT_TGT_PLATFORM_KEY",
    required=True,
    help="Target admin's Platform API key (or env UNSTRACT_TGT_PLATFORM_KEY)",
)
@click.option(
    "--dry-run", is_flag=True, help="Plan only — do not POST anything to target"
)
@click.option(
    "--include",
    default=None,
    help="Comma-separated phase names to include (default: all)",
)
@click.option(
    "--exclude",
    default=None,
    help="Comma-separated phase names to exclude",
)
@click.option(
    "--on-name-conflict",
    type=click.Choice(["adopt", "abort"]),
    default="adopt",
    show_default=True,
    help="What to do when a like-named entity exists in target",
)
@click.option(
    "--api-prefix",
    default="api/v1",
    show_default=True,
    help="Backend URL prefix (matches deployment's PATH_PREFIX env)",
)
@click.option(
    "--file-strategy",
    type=click.Choice(["platform_api", "skip"]),
    default="platform_api",
    show_default=True,
    help="How to move Prompt Studio document files. 'skip' = metadata only.",
)
@click.option(
    "--max-file-size",
    default="25MB",
    show_default=True,
    help="Per-file cap for the files phase. Oversize → reported, not aborted.",
)
@click.option(
    "--skip-files",
    is_flag=True,
    help="Alias for --file-strategy=skip.",
)
@click.option(
    "--concurrency",
    type=click.IntRange(min=1, max=32),
    default=DEFAULT_CONCURRENCY,
    show_default=True,
    help="Per-phase worker count. 1 = strictly sequential.",
)
@click.option("-v", "--verbose", is_flag=True, help="Debug logging")
def clone_cmd(
    source_url: str,
    source_org: str,
    source_key: str,
    target_url: str,
    target_org: str,
    target_key: str,
    dry_run: bool,
    include: str | None,
    exclude: str | None,
    on_name_conflict: str,
    api_prefix: str,
    file_strategy: str,
    max_file_size: str,
    skip_files: bool,
    concurrency: int,
    verbose: bool,
) -> None:
    """Clone configured resources from one org to another."""
    # NOTE: the docstring above is the command's --help text; keep it short.
    # Exit codes: 2 when the clone raises CloneError, 1 when the report is
    # aborted or any phase has failures, 0 otherwise.
    _configure_logging(verbose)

    # --skip-files wins over whatever --file-strategy was given.
    effective_strategy = "skip" if skip_files else file_strategy
    try:
        cap_bytes = _parse_size(max_file_size)
    except click.BadParameter as e:
        # Re-raise as a usage error so click prints it with the command usage.
        raise click.UsageError(str(e)) from e

    include_names = _split_csv(include)
    exclude_names = _split_csv(exclude) or ()
    # Fail fast on misspelt phases instead of silently cloning nothing.
    _reject_unknown_phases("--include", include_names)
    _reject_unknown_phases("--exclude", exclude_names)

    options = CloneOptions(
        dry_run=dry_run,
        include=include_names,
        exclude=exclude_names,
        on_name_conflict=on_name_conflict,
        verbose=verbose,
        file_strategy=effective_strategy,
        # _parse_size always returns an int; the fallback is purely defensive.
        max_file_size=cap_bytes if cap_bytes is not None else DEFAULT_MAX_FILE_SIZE,
        concurrency=concurrency,
    )

    # Both endpoints share the same API prefix; only URL, org and key differ.
    source = OrgEndpoint(
        base_url=source_url,
        organization_id=source_org,
        platform_key=source_key,
        api_path_prefix=api_prefix,
    )
    target = OrgEndpoint(
        base_url=target_url,
        organization_id=target_org,
        platform_key=target_key,
        api_path_prefix=api_prefix,
    )

    try:
        report = run_clone(source, target, options)
    except CloneError as e:
        click.echo(f"Clone failed: {e}", err=True)
        sys.exit(2)

    click.echo(report.render())
    if report.aborted or any(p.failed for p in report.phases):
        sys.exit(1)


def main(argv: list[str] | None = None) -> Any:
    """Run the CLI group; ``argv`` defaults to the process arguments."""
    return cli(args=argv, standalone_mode=True)
logger = logging.getLogger(__name__)

DEFAULT_TIMEOUT = 60


class PlatformClient:
    """HTTP client scoped to a single org via its Platform API key.

    URL shape: ``{base_url}/{api_path_prefix}/unstract/{organization_id}/<path>``
    Auth: ``Authorization: Bearer <platform_key>``.

    Every failure to obtain a usable response (transport error, non-2xx
    status, undecodable body) surfaces as ``PlatformAPIError`` so callers
    only have to handle the clone exception hierarchy.
    """

    def __init__(
        self, endpoint: OrgEndpoint, timeout: int = DEFAULT_TIMEOUT, verify: bool = True
    ):
        """Create a session that sends the endpoint's key on every request.

        Args:
            endpoint: Base URL, organization id, key and API prefix to use.
            timeout: Timeout passed to every request.
            verify: Passed through as the ``verify`` argument of each request.
        """
        self.endpoint = endpoint
        self.timeout = timeout
        self.verify = verify
        self._session = requests.Session()
        self._session.headers.update(
            {
                "Authorization": f"Bearer {endpoint.platform_key}",
                "Accept": "application/json",
            }
        )
        # Cache the OPTIONS-derived writable-field set per entity path.
        # Backend serializer is the single source of truth; we read it once.
        self._post_schema_cache: dict[str, frozenset[str]] = {}

    def close(self) -> None:
        """Release the underlying HTTP connection pool."""
        self._session.close()

    def __enter__(self) -> "PlatformClient":
        return self

    def __exit__(self, *exc: Any) -> None:
        self.close()

    def _url(self, path: str) -> str:
        """Build the absolute URL for an org-scoped ``path``.

        Slashes around the base URL, the API prefix and the start of ``path``
        are normalised so the pieces join with exactly one ``/`` each.
        """
        base = self.endpoint.base_url.rstrip("/")
        api_prefix = self.endpoint.api_path_prefix.strip("/")
        prefix = f"/{api_prefix}/unstract/{self.endpoint.organization_id}/"
        return base + prefix + path.lstrip("/")

    @staticmethod
    def _as_list(result: Any) -> list[dict[str, Any]]:
        """Normalise a list-endpoint response to a plain list.

        Accepts a bare list, an envelope dict carrying the rows under
        ``results``, or ``None`` (what ``_request`` returns for an empty
        body), which yields an empty list instead of an ``AttributeError``.
        """
        if isinstance(result, list):
            return result
        if isinstance(result, dict):
            return result.get("results", [])
        return []

    def _request(
        self,
        method: str,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        json: Any = None,
        files: dict[str, Any] | None = None,
        data: dict[str, Any] | None = None,
    ) -> Any:
        """Send one request and return the decoded JSON body.

        Returns:
            The parsed JSON, or ``None`` for a 204 or an empty body.

        Raises:
            PlatformAPIError: On a transport failure, a non-2xx status, or a
                2xx response whose body is not valid JSON.
        """
        url = self._url(path)
        # Redact secrets from logs: only entity path + method, never body.
        logger.debug("%s %s", method, url)
        try:
            resp = self._session.request(
                method,
                url,
                params=params,
                json=json,
                files=files,
                data=data,
                timeout=self.timeout,
                verify=self.verify,
            )
        except requests.RequestException as e:
            # Connection errors, timeouts, TLS failures: surface them through
            # the clone hierarchy so ``except CloneError`` handlers see them.
            raise PlatformAPIError(f"{method} {path} failed: {e}") from e
        if not 200 <= resp.status_code < 300:
            raise PlatformAPIError(
                f"{method} {path} returned {resp.status_code}",
                status_code=resp.status_code,
                body=resp.text[:2000],
            )
        if resp.status_code == 204 or not resp.content:
            return None
        try:
            return resp.json()
        except ValueError as e:
            # e.g. an HTML page from a proxy answering with a 2xx status.
            raise PlatformAPIError(
                f"{method} {path} returned a non-JSON body",
                status_code=resp.status_code,
                body=resp.text[:2000],
            ) from e

    def get_post_schema(self, entity_path: str) -> frozenset[str]:
        """Return the set of fields the backend's POST serializer accepts.

        Reads it from a DRF ``OPTIONS`` response (``actions.POST``) once
        per path and caches the result. DRF ``SimpleMetadata`` already
        excludes ``read_only`` fields from ``actions.POST``, so the
        returned set is exactly the writable subset.
        """
        cached = self._post_schema_cache.get(entity_path)
        if cached is not None:
            return cached
        body = self._request("OPTIONS", entity_path)
        actions = (body or {}).get("actions") or {}
        post_block = actions.get("POST") or {}
        writable = frozenset(
            name for name, meta in post_block.items() if not meta.get("read_only")
        )
        self._post_schema_cache[entity_path] = writable
        return writable

    # ----- adapters -----

    def list_adapters(
        self,
        *,
        name: str | None = None,
        adapter_type: str | None = None,
    ) -> list[dict[str, Any]]:
        """List adapters in this org, optionally filtered by name and/or type."""
        params: dict[str, Any] = {}
        if name is not None:
            params["adapter_name"] = name
        if adapter_type is not None:
            params["adapter_type"] = adapter_type
        return self._as_list(self._request("GET", "adapter/", params=params))

    def get_adapter(self, adapter_pk: str) -> dict[str, Any]:
        """Fetch a single adapter by primary key."""
        return self._request("GET", f"adapter/{adapter_pk}/")

    def create_adapter(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Create an adapter from ``payload``."""
        return self._request("POST", "adapter/", json=payload)

    # ----- connectors -----

    def list_connectors(
        self,
        *,
        name: str | None = None,
        connector_type: str | None = None,
    ) -> list[dict[str, Any]]:
        """List connectors in this org, optionally filtered by name and/or type."""
        params: dict[str, Any] = {}
        if name is not None:
            params["connector_name"] = name
        if connector_type is not None:
            params["connector_type"] = connector_type
        return self._as_list(self._request("GET", "connector/", params=params))

    def get_connector(self, connector_pk: str) -> dict[str, Any]:
        """Fetch a single connector by primary key."""
        return self._request("GET", f"connector/{connector_pk}/")

    def create_connector(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Create a connector from ``payload``."""
        return self._request("POST", "connector/", json=payload)

    # ----- tags -----

    def list_tags(self, *, name: str | None = None) -> list[dict[str, Any]]:
        """List tags in this org, optionally filtered by exact name."""
        params: dict[str, Any] = {}
        if name is not None:
            params["name"] = name
        return self._as_list(self._request("GET", "tags/", params=params))

    def create_tag(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Create a tag from ``payload``."""
        return self._request("POST", "tags/", json=payload)

    # ----- custom tools (prompt studio) -----

    def list_custom_tools(self) -> list[dict[str, Any]]:
        """List all prompt-studio projects in this org. No name filter."""
        return self._as_list(self._request("GET", "prompt-studio/"))

    def get_custom_tool(self, tool_id: str) -> dict[str, Any]:
        """Fetch a single prompt-studio project (full serializer).

        Returns ``fields = "__all__"`` per ``CustomToolSerializer`` —
        notably includes ``output`` (the default DocumentManager id the
        FE binds to ``selectedDoc`` on load).
        """
        return self._request("GET", f"prompt-studio/{tool_id}/")

    def update_custom_tool(self, tool_id: str, body: dict[str, Any]) -> dict[str, Any]:
        """PATCH a prompt-studio project.

        Used to set ``output`` (the default doc id) after the files phase
        populates DM rows.
        """
        return self._request("PATCH", f"prompt-studio/{tool_id}/", json=body)

    def list_profiles(self, tool_id: str) -> list[dict[str, Any]]:
        """List ProfileManager rows for a tool.

        The clone reads this on the source only — to discover the
        default profile's adapter UUIDs so they can be remapped to
        target adapter ids for ``import_project``.
        """
        return self._as_list(
            self._request("GET", f"prompt-studio/prompt-studio-profile/{tool_id}/")
        )

    def export_project(self, tool_id: str) -> dict[str, Any]:
        """Export a prompt-studio project as a portable JSON blob.

        Bundles ``tool_metadata``, ``tool_settings``,
        ``default_profile_settings``, ``prompts``, ``export_metadata`` in
        one shot — feed straight into ``import_project`` or
        ``sync_prompts`` on the target.
        """
        return self._request("GET", f"prompt-studio/project-transfer/{tool_id}")

    def import_project(
        self,
        export_data: dict[str, Any],
        adapter_ids: dict[str, str | None] | None = None,
    ) -> dict[str, Any]:
        """Import a prompt-studio project from an export blob.

        Backend creates the tool, builds the default ProfileManager from
        the supplied target-org adapter ids, and imports all prompts in
        one call. On name collision the backend silently uniquifies the
        new tool's name — callers should pre-check via
        ``list_custom_tools`` to avoid that.

        ``adapter_ids`` keys are the backend's form fields:
        ``llm_adapter_id``, ``vector_db_adapter_id``,
        ``embedding_adapter_id``, ``x2text_adapter_id``. All four
        required to wire the profile; otherwise backend falls back to
        a profile without adapters and flags ``needs_adapter_config``.
        """
        tool_name = export_data.get("tool_metadata", {}).get("tool_name") or "export"
        # The export blob travels as an uploaded JSON file, not a JSON body.
        content = json_lib.dumps(export_data).encode()
        files = {"file": (f"{tool_name}.json", content, "application/json")}
        data: dict[str, Any] = {}
        if adapter_ids:
            # Forward only the recognised form fields that have a value.
            for key in (
                "llm_adapter_id",
                "vector_db_adapter_id",
                "embedding_adapter_id",
                "x2text_adapter_id",
            ):
                val = adapter_ids.get(key)
                if val:
                    data[key] = val
        return self._request(
            "POST",
            "prompt-studio/project-transfer/",
            files=files,
            data=data,
        )

    def sync_prompts(
        self,
        tool_id: str,
        export_data: dict[str, Any],
        *,
        create_copy: bool = False,
    ) -> dict[str, Any]:
        """Rip-and-replace prompts on an existing target tool.

        Adopt path: target tool already exists with its own
        adapter-bound profiles. This overwrites its prompt set (and
        ``tool_settings``) from source; profiles and uploaded documents
        are left untouched.
        """
        payload = {"data": export_data, "create_copy": create_copy}
        return self._request(
            "POST", f"prompt-studio/{tool_id}/sync-prompts/", json=payload
        )

    def list_prompt_documents(self, tool_id: str) -> list[dict[str, Any]]:
        """List DocumentManager rows for a tool.

        Used by FilesPhase for target-side idempotency and source-side
        enumeration. Response items carry ``document_id``,
        ``document_name``, and ``tool`` (per the serializer's
        ``to_representation`` filter).
        """
        return self._as_list(
            self._request(
                "GET", "prompt-studio/prompt-document/", params={"tool_id": tool_id}
            )
        )

    def download_prompt_file(self, tool_id: str, document_id: str) -> dict[str, Any]:
        """GET a Prompt Studio document by tool + DM row id.

        ``fetch_contents_ide`` resolves the filename internally from the
        DocumentManager row, so the SDK passes the ``document_id`` it
        already has from ``list_prompt_documents`` rather than reposting
        the filename. Returns ``{"data": ..., "mime_type": ...}`` —
        PDFs base64, text/csv utf-8, Excel placeholder.
        """
        return self._request(
            "GET",
            f"prompt-studio/file/{tool_id}",
            params={"document_id": document_id},
        )

    def upload_prompt_file(
        self,
        tool_id: str,
        file_name: str,
        data: bytes,
        mime_type: str,
    ) -> dict[str, Any]:
        """Upload a file into a target Prompt Studio tool.

        Backend writes bytes to storage and creates a ``DocumentManager``
        row. The DM model has ``UniqueConstraint(document_name, tool)``,
        so callers must pre-check via ``list_prompt_documents`` to avoid
        an IntegrityError → 500 on re-runs.
        """
        files = {"file": (file_name, data, mime_type)}
        return self._request("POST", f"prompt-studio/file/{tool_id}", files=files)

    def export_custom_tool(self, tool_id: str, *, force: bool = True) -> Any:
        """Republish ``PromptStudioRegistry`` from the tool's current state.

        Called after import/sync so the registry row reflects the
        freshly landed prompts. Required for ToolInstancePhase to find
        a target registry id to remap.
        """
        return self._request(
            "POST",
            f"prompt-studio/export/{tool_id}",
            json={
                "is_shared_with_org": False,
                "user_id": [],
                "force_export": force,
            },
        )

    # ----- workflows -----

    def list_workflows(self, *, name: str | None = None) -> list[dict[str, Any]]:
        """List workflows in this org, optionally filtered by exact name."""
        params: dict[str, Any] = {}
        if name is not None:
            params["workflow_name"] = name
        return self._as_list(self._request("GET", "workflow/", params=params))

    def get_workflow(self, workflow_id: str) -> dict[str, Any]:
        """Fetch a single workflow by id."""
        return self._request("GET", f"workflow/{workflow_id}/")

    def create_workflow(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Create a workflow. Backend auto-creates empty WorkflowEndpoints for it."""
        return self._request("POST", "workflow/", json=payload)

    # ----- prompt studio registry -----

    def list_registries(
        self, *, custom_tool: str | None = None
    ) -> list[dict[str, Any]]:
        """List PromptStudioRegistry rows.

        The list endpoint returns nothing unless a filter is supplied; pass
        ``custom_tool`` to look up the registry id for a given tool.
        """
        params: dict[str, Any] = {}
        if custom_tool is not None:
            params["custom_tool"] = custom_tool
        return self._as_list(
            self._request("GET", "prompt-studio/registry/", params=params)
        )

    # ----- tool instances -----

    def list_tool_instances(
        self, *, workflow_id: str | None = None
    ) -> list[dict[str, Any]]:
        """List ToolInstance rows, optionally scoped to a workflow."""
        params: dict[str, Any] = {}
        if workflow_id is not None:
            params["workflow"] = workflow_id
        return self._as_list(self._request("GET", "tool_instance/", params=params))

    def create_tool_instance(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Create a tool instance (max 1 per workflow).

        The backend overwrites the ``metadata`` field with tool defaults —
        caller must PATCH after create to transfer source metadata.
        """
        return self._request("POST", "tool_instance/", json=payload)

    def update_tool_instance_metadata(
        self, instance_id: str, metadata: dict[str, Any]
    ) -> dict[str, Any]:
        """PATCH a tool instance's metadata.

        Backend resolves adapter names in the payload to local UUIDs via
        ``update_instance_metadata``.
        """
        return self._request(
            "PATCH", f"tool_instance/{instance_id}/", json={"metadata": metadata}
        )

    # ----- workflow endpoints -----

    def list_workflow_endpoints(
        self, *, workflow_id: str | None = None
    ) -> list[dict[str, Any]]:
        """List workflow endpoints, optionally filtered by workflow id.

        The backend auto-creates one SOURCE and one DESTINATION endpoint
        per workflow, so a workflow filter typically returns exactly two
        rows.
        """
        params: dict[str, Any] = {}
        if workflow_id is not None:
            params["workflow"] = workflow_id
        return self._as_list(
            self._request("GET", "workflow/endpoint/", params=params)
        )

    def update_workflow_endpoint(
        self, endpoint_id: str, payload: dict[str, Any]
    ) -> dict[str, Any]:
        """PATCH a workflow endpoint with ``payload``."""
        return self._request("PATCH", f"workflow/endpoint/{endpoint_id}/", json=payload)

    # ----- pipelines (ETL / TASK) -----

    def list_pipelines(
        self,
        *,
        name: str | None = None,
        pipeline_type: str | None = None,
    ) -> list[dict[str, Any]]:
        """List pipelines in this org, optionally filtered by exact name
        and/or pipeline_type (``ETL`` / ``TASK`` / ``APP``).
        """
        params: dict[str, Any] = {}
        if name is not None:
            params["pipeline_name"] = name
        if pipeline_type is not None:
            params["type"] = pipeline_type
        return self._as_list(self._request("GET", "pipeline/", params=params))

    def get_pipeline(self, pipeline_id: str) -> dict[str, Any]:
        """Fetch a single pipeline by id."""
        return self._request("GET", f"pipeline/{pipeline_id}/")

    def create_pipeline(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Create a pipeline.

        Backend force-sets ``active=True`` and auto-creates a single active
        API key on the new pipeline.
        """
        return self._request("POST", "pipeline/", json=payload)

    def update_pipeline(
        self, pipeline_id: str, payload: dict[str, Any]
    ) -> dict[str, Any]:
        """PATCH a pipeline with ``payload``."""
        return self._request("PATCH", f"pipeline/{pipeline_id}/", json=payload)

    # ----- API deployments -----

    def list_api_deployments(
        self,
        *,
        api_name: str | None = None,
    ) -> list[dict[str, Any]]:
        """List API deployments in this org, optionally filtered by exact api_name."""
        params: dict[str, Any] = {}
        if api_name is not None:
            params["api_name"] = api_name
        return self._as_list(self._request("GET", "api/deployment/", params=params))

    def get_api_deployment(self, deployment_id: str) -> dict[str, Any]:
        """Fetch a single API deployment by id."""
        return self._request("GET", f"api/deployment/{deployment_id}/")

    def create_api_deployment(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Create an API deployment.

        Backend auto-creates a single active key and returns it in the
        response under ``api_key``.
        """
        return self._request("POST", "api/deployment/", json=payload)

    def update_api_deployment(
        self, deployment_id: str, payload: dict[str, Any]
    ) -> dict[str, Any]:
        """PATCH an API deployment with ``payload``."""
        return self._request("PATCH", f"api/deployment/{deployment_id}/", json=payload)

    # ----- API keys (per pipeline / deployment) -----

    def list_pipeline_keys(self, pipeline_id: str) -> list[dict[str, Any]]:
        """List API keys belonging to a pipeline."""
        return self._as_list(
            self._request("GET", f"api/keys/pipeline/{pipeline_id}/")
        )

    def list_api_deployment_keys(self, deployment_id: str) -> list[dict[str, Any]]:
        """List API keys belonging to an API deployment."""
        return self._as_list(self._request("GET", f"api/keys/api/{deployment_id}/"))

    def create_api_key(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Create an extra API key tied to a pipeline or deployment.

        Used to mirror non-default keys (e.g. an additional rotated key)
        on the target. The ``api_key`` UUID itself is server-generated
        and cannot be carried over from source.
        """
        return self._request("POST", "api/keys/api/", json=payload)
@dataclass(frozen=True)
class OrgEndpoint:
    """Immutable description of one side of a clone.

    Holds the deployment base URL, the organization id used in the URL
    path, the Platform API key sent as the bearer token, and the API path
    prefix (``api/v1`` unless overridden).
    """

    base_url: str
    organization_id: str
    platform_key: str
    api_path_prefix: str = "api/v1"


DEFAULT_MAX_FILE_SIZE = 25 * 1024 * 1024  # 25 MB per-file cap
DEFAULT_CONCURRENCY = 4


@dataclass
class CloneOptions:
    """Flags that tune a single ``clone()`` run."""

    dry_run: bool = False
    # ``None`` means "no allow-list": every phase not excluded is eligible.
    include: tuple[str, ...] | None = None
    exclude: tuple[str, ...] = ()
    on_name_conflict: str = "adopt"  # either "adopt" or "abort"
    verbose: bool = False
    # "platform_api" moves file bytes through the API (default);
    # "skip" leaves bytes behind and clones metadata only.
    file_strategy: str = "platform_api"
    max_file_size: int = DEFAULT_MAX_FILE_SIZE
    # Worker count per phase; 1 means strictly sequential.
    concurrency: int = DEFAULT_CONCURRENCY

    def includes(self, phase_name: str) -> bool:
        """Tell whether ``phase_name`` should run under these options.

        A phase runs when it is not excluded and either no allow-list was
        given or the phase is on it.
        """
        if phase_name in self.exclude:
            return False
        return self.include is None or phase_name in self.include


class RemapTable:
    """Source-UUID to target-UUID mappings, grouped by entity type.

    ``resolve`` looks inside one entity type; ``resolve_any`` searches all
    of them for callers that hold a UUID without knowing what it refers to.
    """

    def __init__(self) -> None:
        self._table: dict[str, dict[str, str]] = {}

    def record(self, entity: str, src_uuid: str, tgt_uuid: str) -> None:
        """Remember that ``src_uuid`` of type ``entity`` became ``tgt_uuid``."""
        bucket = self._table.get(entity)
        if bucket is None:
            bucket = self._table[entity] = {}
        bucket[src_uuid] = tgt_uuid

    def resolve(self, entity: str, src_uuid: str) -> str | None:
        """Return the target UUID recorded for ``entity``/``src_uuid``, if any."""
        bucket = self._table.get(entity)
        return None if bucket is None else bucket.get(src_uuid)

    def resolve_any(self, src_uuid: str) -> str | None:
        """Return the first non-``None`` mapping for ``src_uuid`` in any entity type."""
        return next(
            (
                bucket[src_uuid]
                for bucket in self._table.values()
                if bucket.get(src_uuid) is not None
            ),
            None,
        )

    def snapshot(self) -> dict[str, dict[str, str]]:
        """Return a copy of the table that can be mutated without affecting it."""
        copied: dict[str, dict[str, str]] = {}
        for entity, bucket in self._table.items():
            copied[entity] = dict(bucket)
        return copied


@dataclass
class CloneContext:
    """Everything one ``clone()`` invocation shares between its phases.

    Carries the source and target clients, the run options, and the
    ``RemapTable`` (a fresh one per context unless supplied).
    """

    source: PlatformClient
    target: PlatformClient
    options: CloneOptions
    remap: RemapTable = field(default_factory=RemapTable)
class CloneError(Exception):
    """Root of the clone exception hierarchy."""


class PlatformAPIError(CloneError):
    """A Platform API call that produced a non-2xx response the clone cannot recover from."""

    def __init__(
        self, message: str, status_code: int | None = None, body: str | None = None
    ):
        # A non-empty body is appended on its own line so str(exc) shows it.
        if body:
            text = f"{message}\n body: {body}"
        else:
            text = message
        super().__init__(text)
        self.status_code = status_code
        self.body = body


class NameConflictError(CloneError):
    """The target already has a like-named entity and ``on_name_conflict='abort'`` is set."""


class DependencyMissingError(CloneError):
    """A phase referenced a source UUID for which no earlier phase recorded a mapping."""
--git a/src/unstract/clone/orchestrator.py b/src/unstract/clone/orchestrator.py new file mode 100644 index 0000000..a7c81a0 --- /dev/null +++ b/src/unstract/clone/orchestrator.py @@ -0,0 +1,115 @@ +"""Top-level ``clone()`` entry point. + +Wires source/target ``PlatformClient`` instances, builds a +``CloneContext``, runs each phase in strict topological order, and +returns a ``CloneReport``. + +Phase order is owned here — phases must not call each other. Adding a new +entity type means: write a new ``Phase`` subclass and append it to +``PHASES`` at the right dependency position. +""" + +from __future__ import annotations + +import logging +import time + +from unstract.clone.client import PlatformClient +from unstract.clone.context import CloneContext, CloneOptions, OrgEndpoint +from unstract.clone.exceptions import CloneError +from unstract.clone.phases import ( + AdapterPhase, + APIDeploymentPhase, + ConnectorPhase, + CustomToolPhase, + FilesPhase, + PipelinePhase, + TagPhase, + ToolInstancePhase, + WorkflowEndpointPhase, + WorkflowPhase, +) +from unstract.clone.phases.base import Phase +from unstract.clone.report import CloneReport, Endpoint + +logger = logging.getLogger(__name__) + +# Strict dependency order. Each entry: (phase_name, phase_class). +# Adapter, connector, tag are independent leaf phases. Downstream phases +# (custom_tool, workflow, tool_instance, workflow_endpoint) land later +# and consume the remap entries these produce. Pipeline + api_deployment +# come last: both FK the workflow and api_deployment additionally +# requires endpoints to be configured before the serializer accepts it. 
# Execution order of the phases. Each entry is (phase_name, phase_class);
# later phases consume remap entries recorded by earlier ones.
PHASES: list[tuple[str, type[Phase]]] = [
    ("adapter", AdapterPhase),
    ("connector", ConnectorPhase),
    ("tag", TagPhase),
    ("custom_tool", CustomToolPhase),
    ("files", FilesPhase),
    ("workflow", WorkflowPhase),
    ("tool_instance", ToolInstancePhase),
    ("workflow_endpoint", WorkflowEndpointPhase),
    ("pipeline", PipelinePhase),
    ("api_deployment", APIDeploymentPhase),
]


def clone(
    source: OrgEndpoint,
    target: OrgEndpoint,
    options: CloneOptions | None = None,
) -> CloneReport:
    """Migrate configured resources from one org to another.

    Runs every entry of ``PHASES`` in order against a shared
    ``CloneContext``. Phases excluded by ``options`` are listed in
    ``report.skipped_phases``. The first phase that raises ``CloneError``
    marks the report as aborted and stops the run; its duration is still
    recorded.

    Args:
        source: Endpoint of the org to read from.
        target: Endpoint of the org to write to.
        options: Run options; defaults to ``CloneOptions()``.

    Returns:
        The ``CloneReport`` for the run, including on abort.

    Raises:
        Exception: Anything raised while building the clients, or any
            non-``CloneError`` raised by a phase. Both clients are closed
            before the exception propagates.
    """
    opts = options or CloneOptions()
    src_client = PlatformClient(source)
    try:
        tgt_client = PlatformClient(target)
    except BaseException:
        # The source client already exists at this point; release it so a
        # bad target endpoint does not leak it.
        src_client.close()
        raise

    try:
        ctx = CloneContext(
            source=src_client,
            target=tgt_client,
            options=opts,
        )
        report = CloneReport(
            source=Endpoint(
                base_url=source.base_url, organization_id=source.organization_id
            ),
            target=Endpoint(
                base_url=target.base_url, organization_id=target.organization_id
            ),
        )

        run_started = time.perf_counter()
        for name, phase_cls in PHASES:
            if not opts.includes(name):
                report.skipped_phases.append(name)
                logger.info("Phase '%s' skipped (excluded)", name)
                continue
            logger.info("=== Phase: %s ===", name)
            phase_started = time.perf_counter()
            try:
                phase_cls(ctx).run(report)
            except CloneError as e:
                report.aborted = True
                report.abort_reason = str(e)
                logger.error("Phase '%s' aborted: %s", name, e)
                # Stamp duration even on abort so the report reflects time spent.
                report.get_phase(name).duration_s = time.perf_counter() - phase_started
                break
            else:
                report.get_phase(name).duration_s = time.perf_counter() - phase_started
                logger.info(
                    "=== Phase '%s' done in %.2fs ===",
                    name,
                    report.get_phase(name).duration_s,
                )

        report.total_duration_s = time.perf_counter() - run_started
        report.remap_snapshot = ctx.remap.snapshot()
        return report
    finally:
        # Nested so the target client is closed even when closing the
        # source client raises.
        try:
            src_client.close()
        finally:
            tgt_client.close()
class AdapterPhase(Phase):
    """Get-or-create every source adapter on the target.

    Target lookups are keyed by adapter name and adapter type. Each
    adopted or created adapter is recorded in the ``adapter`` remap.
    """

    name = "adapter"

    def run(self, report: CloneReport) -> PhaseResult:
        """Clone all source adapters and return this phase's result.

        A failure to fetch the target POST schema or to list the source
        adapters is recorded as one failure and ends the phase early.
        """
        result = report.get_phase(self.name)

        try:
            self._writable = self.ctx.target.get_post_schema(ADAPTER_PATH)
        except Exception as e:
            logger.exception("Failed to fetch target POST schema for adapter: %s", e)
            result.failed += 1
            result.errors.append(f"OPTIONS adapter: {e}")
            return result

        try:
            summaries = self.ctx.source.list_adapters()
        except Exception as e:
            logger.exception("Failed to list source adapters: %s", e)
            result.failed += 1
            result.errors.append(f"list source adapters: {e}")
            return result

        logger.info("Found %d adapter(s) in source org", len(summaries))

        def worker(summary: dict[str, Any], lock: threading.Lock) -> None:
            self._clone_one(summary, result, lock)

        self.parallel_map(summaries, worker)
        return result

    @staticmethod
    def _note_failure(result: PhaseResult, lock: threading.Lock, message: str) -> None:
        """Count one failure and store its message while holding ``lock``."""
        with lock:
            result.failed += 1
            result.errors.append(message)

    def _clone_one(
        self, summary: dict[str, Any], result: PhaseResult, lock: threading.Lock
    ) -> None:
        """Adopt or create one adapter on the target and record its remap.

        Raises:
            NameConflictError: The adapter already exists on the target and
                ``on_name_conflict`` is ``'abort'``.
        """
        name, atype, src_id = (
            summary["adapter_name"],
            summary["adapter_type"],
            summary["id"],
        )

        # The list summary is not used for the payload; fetch the detail row.
        try:
            detail = self.ctx.source.get_adapter(src_id)
        except Exception as e:
            logger.exception(
                "Failed to GET source adapter %s [%s] detail: %s", name, atype, e
            )
            self._note_failure(
                result, lock, f"GET source detail {name} [{atype}]: {e}"
            )
            return

        try:
            matches = self.ctx.target.list_adapters(name=name, adapter_type=atype)
        except Exception as e:
            logger.exception(
                "Failed to GET adapter %s [%s] on target: %s", name, atype, e
            )
            self._note_failure(result, lock, f"GET {name} [{atype}]: {e}")
            return

        # Nothing on the target and we may not write: count it and stop
        # without recording a remap.
        if not matches and self.ctx.options.dry_run:
            with lock:
                result.skipped += 1
            logger.info(
                "[dry-run] would create adapter '%s' [%s] src=%s", name, atype, src_id
            )
            return

        if matches:
            tgt = matches[0]
            if self.ctx.options.on_name_conflict == "abort":
                raise NameConflictError(
                    f"adapter '{name}' [{atype}] already exists in target as {tgt['id']}"
                )
            with lock:
                result.adopted += 1
            logger.info(
                "adopted adapter '%s' [%s] src=%s -> tgt=%s",
                name,
                atype,
                src_id,
                tgt["id"],
            )
        else:
            # Built outside the ``try`` so a payload error is not reported as
            # a create failure.
            payload = build_post_payload(detail, self._writable)
            try:
                tgt = self.ctx.target.create_adapter(payload)
            except Exception as e:
                logger.exception("Failed to create adapter %s [%s]: %s", name, atype, e)
                self._note_failure(result, lock, f"create {name} [{atype}]: {e}")
                return
            with lock:
                result.created += 1
            logger.info(
                "created adapter '%s' [%s] src=%s -> tgt=%s",
                name,
                atype,
                src_id,
                tgt["id"],
            )

        with lock:
            self.ctx.remap.record("adapter", src_id, tgt["id"])
class APIDeploymentPhase(Phase):
    """Adopt-by-name or create every source API deployment on the target.

    The deployment's workflow is resolved through the ``workflow`` remap.
    Deployments whose workflow has no remap entry are skipped.
    """

    name = "api_deployment"

    def run(self, report: CloneReport) -> PhaseResult:
        """Clone all source API deployments and return this phase's result.

        A failure to fetch the target POST schema or to list the source
        deployments is recorded as one failure and ends the phase early.
        """
        result = report.get_phase(self.name)
        try:
            self._writable = self.ctx.target.get_post_schema(API_DEPLOYMENT_PATH)
        except Exception as e:
            logger.exception(
                "Failed to fetch target POST schema for api_deployment: %s", e
            )
            result.failed += 1
            result.errors.append(f"OPTIONS api_deployment: {e}")
            return result

        try:
            src_deployments = self.ctx.source.list_api_deployments()
        except Exception as e:
            logger.exception("Failed to list source api_deployments: %s", e)
            result.failed += 1
            result.errors.append(f"list source api_deployments: {e}")
            return result

        logger.info("Found %d source API deployment(s)", len(src_deployments))
        self.parallel_map(
            src_deployments,
            lambda src, lock: self._clone_one(src, result, lock),
        )
        return result

    def _clone_one(
        self, src: dict[str, Any], result: PhaseResult, lock: threading.Lock
    ) -> None:
        """Adopt or create one API deployment and record its remap.

        Args:
            src: Source deployment row from the list call.
            result: Phase counters; mutated only while holding ``lock``.
            lock: Lock shared by all workers of this phase.

        Raises:
            NameConflictError: The deployment already exists on the target
                and ``on_name_conflict`` is ``'abort'``.
        """
        api_name = src["api_name"]
        src_id = src["id"]
        # The workflow FK is read from either key of the list row.
        src_wf_id = src.get("workflow") or src.get("workflow_id")

        if not src_wf_id:
            logger.warning(
                "source api_deployment '%s' has no workflow FK — skipping", api_name
            )
            with lock:
                result.skipped += 1
            return

        with lock:
            tgt_wf_id = self.ctx.remap.resolve("workflow", src_wf_id)
        if not tgt_wf_id:
            logger.warning(
                "no workflow remap for api_deployment '%s' (src workflow %s) — skipping",
                api_name,
                src_wf_id,
            )
            with lock:
                result.skipped += 1
            return

        try:
            existing = self.ctx.target.list_api_deployments(api_name=api_name)
        except Exception as e:
            logger.exception(
                "Failed to GET api_deployment %s on target: %s", api_name, e
            )
            with lock:
                result.failed += 1
                result.errors.append(f"GET {api_name}: {e}")
            return

        if existing:
            tgt = existing[0]
            if self.ctx.options.on_name_conflict == "abort":
                raise NameConflictError(
                    f"api_deployment '{api_name}' already exists in target as {tgt['id']}"
                )
            with lock:
                result.adopted += 1
            logger.info(
                "adopted api_deployment '%s' src=%s -> tgt=%s",
                api_name,
                src_id,
                tgt["id"],
            )
        elif self.ctx.options.dry_run:
            with lock:
                result.skipped += 1
            logger.info(
                "[dry-run] would create api_deployment '%s' src=%s", api_name, src_id
            )
            return
        else:
            try:
                full_src = self.ctx.source.get_api_deployment(src_id)
            except Exception as e:
                logger.exception(
                    "Failed to GET source api_deployment %s: %s", api_name, e
                )
                with lock:
                    result.failed += 1
                    result.errors.append(f"GET src api_deployment {api_name}: {e}")
                return
            # Sibling workers call ``remap.record`` under this lock. Read the
            # table under the same lock so the walk never sees it change
            # mid-iteration.
            with lock:
                remapped = remap_uuids(full_src, self.ctx.remap)
            payload = build_post_payload(remapped, self._writable)
            # Always set the workflow FK from the explicit remap lookup above.
            payload["workflow"] = tgt_wf_id
            try:
                tgt = self.ctx.target.create_api_deployment(payload)
            except Exception as e:
                logger.exception("Failed to create api_deployment %s: %s", api_name, e)
                with lock:
                    result.failed += 1
                    result.errors.append(f"create {api_name}: {e}")
                return
            with lock:
                result.created += 1
            logger.info(
                "created api_deployment '%s' src=%s -> tgt=%s",
                api_name,
                src_id,
                tgt["id"],
            )
            self._warn_if_extra_source_keys(src_id, api_name)

        with lock:
            self.ctx.remap.record("api_deployment", src_id, tgt["id"])

    def _warn_if_extra_source_keys(self, src_deployment_id: str, name: str) -> None:
        """Log a warning when the source deployment has more than one active key.

        Best-effort: a failure to list the source keys is logged and
        swallowed, never raised.
        """
        try:
            keys = self.ctx.source.list_api_deployment_keys(src_deployment_id)
        except Exception as e:
            # WARNING (not DEBUG) — the operator needs to know we couldn't
            # check whether they have additional keys to recreate manually.
            logger.warning(
                "Could not list source keys for api_deployment %s "
                "(extra-key check skipped; re-verify in source UI): %s",
                name,
                e,
            )
            return
        active = [k for k in keys if k.get("is_active")]
        if len(active) > 1:
            logger.warning(
                "source api_deployment '%s' had %d active API keys; "
                "target has only the auto-provisioned default — "
                "re-create the rest manually if your clients depend on them",
                name,
                len(active),
            )
# Fields never sent in a POST payload, even when the target's OPTIONS
# schema lists them as writable.
SERVER_MANAGED: frozenset[str] = frozenset(
    (
        "id",
        "organization",
        "created_by",
        "created_by_email",
        "modified_by",
        "modified_by_email",
        "created_at",
        "modified_at",
        "shared_users",
    )
)


def build_post_payload(src: dict[str, Any], writable: frozenset[str]) -> dict[str, Any]:
    """Build a POST body from ``src`` restricted to the writable schema.

    Args:
        src: Source entity as returned by the API.
        writable: Field names the target accepts on POST.

    Returns:
        A new dict holding every key that is in ``writable``, not in
        ``SERVER_MANAGED``, present in ``src``, and whose value is neither
        ``None`` nor the empty string.
    """
    payload: dict[str, Any] = {}
    for key in writable - SERVER_MANAGED:
        if key not in src:
            continue
        value = src[key]
        # Only ``None`` and "" are dropped. Other falsy values such as
        # ``False`` and ``0`` are meaningful and must be kept, which is why
        # this is not a plain truthiness test.
        if value is not None and value != "":
            payload[key] = value
    return payload


class Phase(ABC):
    """Base class for clone phases; one concrete subclass per entity type."""

    name: str = ""

    def __init__(self, ctx: CloneContext):
        self.ctx = ctx

    @abstractmethod
    def run(self, report: CloneReport) -> PhaseResult:
        """Migrate all entities of this phase's type. Idempotent across runs."""
        raise NotImplementedError

    def parallel_map(
        self,
        items: Iterable[T],
        work_fn: Callable[[T, threading.Lock], None],
    ) -> None:
        """Run ``work_fn(item, lock)`` for every item.

        Up to ``ctx.options.concurrency`` worker threads are used. A value
        of 1 or less runs the items inline on the calling thread, where any
        exception propagates immediately. All calls share one lock, which
        ``work_fn`` must hold while mutating shared state.

        Raises:
            CloneError: A worker raised one; futures not yet started are
                cancelled. Preferred over any other worker error.
            BaseException: A worker raised something else and no
                ``CloneError`` was seen.
        """
        work_items = list(items)
        if not work_items:
            return

        workers = max(1, self.ctx.options.concurrency)
        shared_lock = threading.Lock()

        if workers == 1:
            for entry in work_items:
                work_fn(entry, shared_lock)
            return

        with ThreadPoolExecutor(
            max_workers=workers,
            thread_name_prefix=f"clone-{self.name}",
        ) as pool:
            futures: list[Future[None]] = []
            for entry in work_items:
                futures.append(pool.submit(work_fn, entry, shared_lock))

            # Returns once every future is done or the first one has raised.
            finished, _unfinished = wait(futures, return_when=FIRST_EXCEPTION)

            clone_err: CloneError | None = None
            other_err: BaseException | None = None
            for fut in finished:
                if fut.cancelled():
                    continue
                exc = fut.exception()
                if exc is None:
                    continue
                if isinstance(exc, CloneError) and clone_err is None:
                    clone_err = exc
                elif other_err is None:
                    other_err = exc

            if clone_err is None and other_err is None:
                return

            # Stop work that has not started yet; leaving the ``with`` block
            # still waits for workers that are already running.
            for fut in futures:
                fut.cancel()
            if clone_err is not None:
                raise clone_err
            if other_err is not None:
                raise other_err
class ConnectorPhase(Phase):
    """Get-or-create every source connector on the target, keyed by name.

    Connectors whose source detail has no ``connector_metadata`` are
    skipped and get no remap entry.
    """

    name = "connector"

    def run(self, report: CloneReport) -> PhaseResult:
        """Clone all source connectors and return this phase's result.

        A failure to fetch the target POST schema or to list the source
        connectors is recorded as one failure and ends the phase early.
        """
        result = report.get_phase(self.name)

        try:
            self._writable = self.ctx.target.get_post_schema(CONNECTOR_PATH)
        except Exception as e:
            logger.exception("Failed to fetch target POST schema for connector: %s", e)
            result.failed += 1
            result.errors.append(f"OPTIONS connector: {e}")
            return result

        try:
            summaries = self.ctx.source.list_connectors()
        except Exception as e:
            logger.exception("Failed to list source connectors: %s", e)
            result.failed += 1
            result.errors.append(f"list source connectors: {e}")
            return result

        logger.info("Found %d connector(s) in source org", len(summaries))

        def worker(summary: dict[str, Any], lock: threading.Lock) -> None:
            self._clone_one(summary, result, lock)

        self.parallel_map(summaries, worker)
        return result

    @staticmethod
    def _note_failure(result: PhaseResult, lock: threading.Lock, message: str) -> None:
        """Count one failure and store its message while holding ``lock``."""
        with lock:
            result.failed += 1
            result.errors.append(message)

    def _clone_one(
        self, summary: dict[str, Any], result: PhaseResult, lock: threading.Lock
    ) -> None:
        """Adopt or create one connector on the target and record its remap.

        Raises:
            NameConflictError: The connector already exists on the target
                and ``on_name_conflict`` is ``'abort'``.
        """
        name, src_id = summary["connector_name"], summary["id"]

        try:
            detail = self.ctx.source.get_connector(src_id)
        except Exception as e:
            logger.exception("Failed to GET source connector %s detail: %s", name, e)
            self._note_failure(result, lock, f"GET source detail {name}: {e}")
            return

        # Without metadata there is nothing to rebuild the connector from.
        if not detail.get("connector_metadata"):
            logger.info(
                "skipping connector '%s' (src=%s, catalog=%s) — "
                "source returned no metadata",
                name,
                src_id,
                detail.get("connector_id"),
            )
            with lock:
                result.skipped += 1
            return

        try:
            matches = self.ctx.target.list_connectors(name=name)
        except Exception as e:
            logger.exception("Failed to GET connector %s on target: %s", name, e)
            self._note_failure(result, lock, f"GET {name}: {e}")
            return

        # Nothing on the target and we may not write: count it and stop
        # without recording a remap.
        if not matches and self.ctx.options.dry_run:
            with lock:
                result.skipped += 1
            logger.info("[dry-run] would create connector '%s' src=%s", name, src_id)
            return

        if matches:
            tgt = matches[0]
            if self.ctx.options.on_name_conflict == "abort":
                raise NameConflictError(
                    f"connector '{name}' already exists in target as {tgt['id']}"
                )
            with lock:
                result.adopted += 1
            logger.info(
                "adopted connector '%s' src=%s -> tgt=%s",
                name,
                src_id,
                tgt["id"],
            )
        else:
            # Built outside the ``try`` so a payload error is not reported as
            # a create failure.
            payload = build_post_payload(detail, self._writable)
            try:
                tgt = self.ctx.target.create_connector(payload)
            except Exception as e:
                logger.exception("Failed to create connector %s: %s", name, e)
                self._note_failure(result, lock, f"create {name}: {e}")
                return
            with lock:
                result.created += 1
            logger.info(
                "created connector '%s' src=%s -> tgt=%s",
                name,
                src_id,
                tgt["id"],
            )

        with lock:
            self.ctx.remap.record("connector", src_id, tgt["id"])
# Pairs of (field name on the source profile, adapter-id form field sent on
# import). All four must resolve or the fresh import is refused.
_PROFILE_ADAPTER_FIELDS: tuple[tuple[str, str], ...] = (
    ("llm", "llm_adapter_id"),
    ("vector_store", "vector_db_adapter_id"),
    ("embedding_model", "embedding_adapter_id"),
    ("x2text", "x2text_adapter_id"),
)


def _extract_adapter_name(value: Any) -> str | None:
    """Return the adapter name carried by a profile field, or ``None``.

    Adapter FKs serialise as the adapter NAME on the wire; tolerate a
    nested-dict shape too. Never fall back to the UUID — list_adapters
    matches by name and would silently miss.

    A string is returned as-is (empty string becomes ``None``). A dict
    yields its ``adapter_name`` or, failing that, ``name``. Any other type
    yields ``None``.
    """
    if isinstance(value, str):
        return value or None
    if isinstance(value, dict):
        return value.get("adapter_name") or value.get("name")
    return None


class CustomToolPhase(Phase):
    """Clone Prompt Studio tools: import when missing on target, sync when present.

    Records the ``custom_tool`` remap for every tool handled and, outside
    dry-run, the ``prompt_studio_registry`` remap after the target registry
    has been republished.
    """

    name = "custom_tool"

    def run(self, report: CloneReport) -> PhaseResult:
        """Clone all source tools and return this phase's result.

        A failure to list tools on either side is recorded as one failure
        and ends the phase early.
        """
        result = report.get_phase(self.name)
        try:
            src_tools = self.ctx.source.list_custom_tools()
        except Exception as e:
            logger.exception("Failed to list source custom tools: %s", e)
            result.failed += 1
            result.errors.append(f"list source custom tools: {e}")
            return result

        logger.info("Found %d custom tool(s) in source org", len(src_tools))
        try:
            target_tools = self.ctx.target.list_custom_tools()
        except Exception as e:
            logger.exception("Failed to list target tools: %s", e)
            result.failed += 1
            result.errors.append(f"list target tools: {e}")
            return result

        # Updated under lock when a fresh create lands so duplicate
        # same-name source rows adopt instead of recreating.
        # NOTE(review): the lookup and the later insert in ``_clone_one``
        # are two separate lock acquisitions, so two same-name source tools
        # processed concurrently could both take the create path — confirm
        # whether duplicate source names are possible.
        target_by_name: dict[str, dict[str, Any]] = {
            t["tool_name"]: t for t in target_tools
        }

        self.parallel_map(
            src_tools,
            lambda summary, lock: self._clone_one(
                summary, target_by_name, result, lock
            ),
        )
        return result

    def _clone_one(
        self,
        summary: dict[str, Any],
        target_by_name: dict[str, dict[str, Any]],
        result: PhaseResult,
        lock: threading.Lock,
    ) -> None:
        """Clone one tool, then republish and remap its registry entry.

        Args:
            summary: Source tool row from the list call.
            target_by_name: Target tools keyed by name, shared by all
                workers and accessed only under ``lock``.
            result: Phase counters; mutated only while holding ``lock``.
            lock: Lock shared by all workers of this phase.
        """
        tool_name = summary["tool_name"]
        src_tool_id = summary["tool_id"]

        try:
            export_data = self.ctx.source.export_project(src_tool_id)
        except Exception as e:
            logger.exception("Failed to export source tool '%s': %s", tool_name, e)
            with lock:
                result.failed += 1
                result.errors.append(f"export src tool {tool_name}: {e}")
            return

        with lock:
            match = target_by_name.get(tool_name)

        if match is not None:
            tgt_tool_id = self._adopt(
                match, export_data, result, tool_name, src_tool_id, lock
            )
        else:
            tgt_tool_id = self._create_fresh(
                export_data, src_tool_id, tool_name, result, lock
            )
            if tgt_tool_id is not None:
                # Make the new tool visible to later same-name source rows.
                with lock:
                    target_by_name[tool_name] = {
                        "tool_id": tgt_tool_id,
                        "tool_name": tool_name,
                    }

        # ``None`` means the helper already counted a failure or a dry-run skip.
        if tgt_tool_id is None:
            return

        with lock:
            self.ctx.remap.record("custom_tool", src_tool_id, tgt_tool_id)

        # Dry-run adopt still records the remap above but must not write.
        if self.ctx.options.dry_run:
            return

        try:
            self.ctx.target.export_custom_tool(tgt_tool_id)
            logger.info(
                "republished registry for tool '%s' tgt=%s", tool_name, tgt_tool_id
            )
        except Exception as e:
            logger.exception("Registry republish failed for tool %s: %s", tool_name, e)
            with lock:
                result.failed += 1
                result.errors.append(f"export {tool_name}: {e}")
            return

        try:
            src_regs = self.ctx.source.list_registries(custom_tool=src_tool_id)
            tgt_regs = self.ctx.target.list_registries(custom_tool=tgt_tool_id)
        except Exception as e:
            logger.warning(
                "registry remap lookup failed for tool '%s' "
                "(downstream ToolInstance clone may skip): %s",
                tool_name,
                e,
            )
            with lock:
                result.failed += 1
                result.errors.append(f"registry remap lookup {tool_name}: {e}")
            return

        # Only the first registry row on each side is mapped; if either
        # side has none, no registry remap is recorded.
        if src_regs and tgt_regs:
            with lock:
                self.ctx.remap.record(
                    "prompt_studio_registry",
                    src_regs[0]["prompt_registry_id"],
                    tgt_regs[0]["prompt_registry_id"],
                )

    def _adopt(
        self,
        match: dict[str, Any],
        export_data: dict[str, Any],
        result: PhaseResult,
        tool_name: str,
        src_tool_id: str,
        lock: threading.Lock,
    ) -> str | None:
        """Sync the source export into an existing target tool.

        Returns:
            The target tool id, or ``None`` when ``sync_prompts`` failed.
            In dry-run the id is returned without calling the target and
            the tool is counted as skipped.

        Raises:
            NameConflictError: ``on_name_conflict`` is ``'abort'``.
        """
        if self.ctx.options.on_name_conflict == "abort":
            raise NameConflictError(
                f"tool '{tool_name}' already exists in target as {match['tool_id']}"
            )

        tgt_tool_id = match["tool_id"]
        if self.ctx.options.dry_run:
            with lock:
                result.skipped += 1
            logger.info(
                "[dry-run] would sync prompts into adopted tool '%s' src=%s -> tgt=%s",
                tool_name,
                src_tool_id,
                tgt_tool_id,
            )
            return tgt_tool_id

        try:
            self.ctx.target.sync_prompts(tgt_tool_id, export_data)
        except Exception as e:
            logger.exception("sync_prompts failed for tool %s: %s", tool_name, e)
            with lock:
                result.failed += 1
                result.errors.append(f"sync {tool_name}: {e}")
            return None

        with lock:
            result.adopted += 1
        logger.info(
            "adopted tool '%s' src=%s -> tgt=%s (prompts re-synced)",
            tool_name,
            src_tool_id,
            tgt_tool_id,
        )
        return tgt_tool_id

    def _create_fresh(
        self,
        export_data: dict[str, Any],
        src_tool_id: str,
        tool_name: str,
        result: PhaseResult,
        lock: threading.Lock,
    ) -> str | None:
        """Import the source export as a new tool on the target.

        Returns:
            The new target tool id, or ``None`` in dry-run (counted as
            skipped), when the adapter ids cannot be resolved, or when the
            import call fails (both counted as failed).
        """
        if self.ctx.options.dry_run:
            with lock:
                result.skipped += 1
            logger.info(
                "[dry-run] would import tool '%s' src=%s", tool_name, src_tool_id
            )
            return None

        adapter_ids = self._resolve_target_adapter_ids(src_tool_id, tool_name)
        if adapter_ids is None:
            # The resolver logs the specific cause; this message is the same
            # for every cause.
            with lock:
                result.failed += 1
                result.errors.append(
                    f"import {tool_name}: missing target adapter remap for default profile"
                )
            return None

        try:
            tgt = self.ctx.target.import_project(export_data, adapter_ids=adapter_ids)
        except Exception as e:
            logger.exception("import_project failed for tool %s: %s", tool_name, e)
            with lock:
                result.failed += 1
                result.errors.append(f"import {tool_name}: {e}")
            return None

        tgt_tool_id = tgt["tool_id"]
        with lock:
            result.created += 1
        logger.info(
            "created tool '%s' src=%s -> tgt=%s (needs_adapter_config=%s)",
            tool_name,
            src_tool_id,
            tgt_tool_id,
            tgt.get("needs_adapter_config"),
        )
        return tgt_tool_id

    def _resolve_target_adapter_ids(
        self, src_tool_id: str, tool_name: str
    ) -> dict[str, str] | None:
        """Source profile carries adapter NAMES (per serializer); resolve
        each name to a target adapter UUID via ``list_adapters(name=...)``.

        Returns ``None`` if any of the four required adapters can't be
        found on target — caller fails the tool. AdapterPhase preserves
        names across orgs so this lookup should always hit when the
        adapter clone ran cleanly.

        The profile used is the first one flagged ``is_default``, falling
        back to the first profile returned. On success the result maps each
        form field of ``_PROFILE_ADAPTER_FIELDS`` to a target adapter id.
        """
        try:
            src_profiles = self.ctx.source.list_profiles(src_tool_id)
        except Exception as e:
            logger.exception(
                "Failed to list source profiles for tool %s: %s", tool_name, e
            )
            return None

        default = next(
            (p for p in src_profiles if p.get("is_default")),
            src_profiles[0] if src_profiles else None,
        )
        if default is None:
            logger.warning(
                "source tool '%s' has no profiles to derive adapter ids from",
                tool_name,
            )
            return None

        resolved: dict[str, str] = {}
        for src_field, form_field in _PROFILE_ADAPTER_FIELDS:
            adapter_name = _extract_adapter_name(default.get(src_field))
            if not adapter_name:
                logger.warning(
                    "source default profile for tool '%s' missing adapter '%s'",
                    tool_name,
                    src_field,
                )
                return None
            try:
                matches = self.ctx.target.list_adapters(name=adapter_name)
            except Exception as e:
                logger.exception(
                    "list_adapters lookup failed for %s on tool '%s': %s",
                    adapter_name,
                    tool_name,
                    e,
                )
                return None
            if not matches:
                logger.warning(
                    "no target adapter named '%s' for field %s on tool '%s'",
                    adapter_name,
                    src_field,
                    tool_name,
                )
                return None
            # NOTE(review): the lookup is by name only and takes the first
            # hit. If adapters of different types can share a name, this
            # could bind the wrong type — confirm, and consider passing
            # ``adapter_type`` here.
            resolved[form_field] = matches[0]["id"]
        return resolved
def run(self, report: CloneReport) -> PhaseResult:
    """Clone Prompt Studio documents for every tool in the ``custom_tool`` remap.

    Runs three passes:

    1. Sequentially list source documents per tool and either emit them as
       skipped (``file_strategy == "skip"``) or build per-file tasks.
    2. Download + upload every task through ``parallel_map``.
    3. Unless this is a dry run or skip mode, set a default document on each
       processed target tool.

    Args:
        report: Run-wide report; per-file outcomes are appended to its
            file lists by the helpers this method calls.

    Returns:
        The ``PhaseResult`` for this phase, updated in place.
    """
    result = report.get_phase(self.name)
    tool_remap = self.ctx.remap.snapshot().get("custom_tool", {})
    if not tool_remap:
        logger.info("files phase: no custom_tool remap entries; nothing to do")
        return result

    strategy = self.ctx.options.file_strategy
    logger.info(
        "files phase: strategy=%s tools=%d cap=%d bytes concurrency=%d",
        strategy,
        len(tool_remap),
        self.ctx.options.max_file_size,
        self.ctx.options.concurrency,
    )

    # Fetch the target tool listing ONCE. Resolving the display name per
    # tool re-listed every custom tool for each remap entry (N listings
    # for N tools) just to label log lines.
    tool_names = self._load_target_tool_names()

    # Pass 1: build per-file task list sequentially (cheap).
    file_tasks: list[_FileTask] = []
    cloned_tools: list[tuple[str, str, str, list[dict[str, Any]]]] = []
    for src_tool_id, tgt_tool_id in tool_remap.items():
        # Fall back to the source id when the name is unknown or empty.
        tool_name = tool_names.get(tgt_tool_id) or src_tool_id
        try:
            src_docs = self.ctx.source.list_prompt_documents(src_tool_id)
        except Exception as e:
            logger.exception(
                "files: failed to list source DM rows for tool %s: %s",
                tool_name,
                e,
            )
            result.failed += 1
            result.errors.append(f"list source docs {tool_name}: {e}")
            continue

        if strategy == "skip":
            self._emit_skip(
                src_docs, src_tool_id, tgt_tool_id, tool_name, report, result
            )
            continue

        tasks = self._build_tool_tasks(
            src_tool_id, tgt_tool_id, tool_name, src_docs, report, result
        )
        file_tasks.extend(tasks)
        cloned_tools.append((src_tool_id, tgt_tool_id, tool_name, src_docs))

    # Pass 2: download + upload each file in parallel.
    if file_tasks:
        self.parallel_map(
            file_tasks,
            lambda task, lock: self._clone_one_file(task, report, result, lock),
        )

    # Pass 3: set default doc per tool after all uploads land.
    if not self.ctx.options.dry_run and strategy != "skip":
        for src_tool_id, tgt_tool_id, tool_name, src_docs in cloned_tools:
            self._ensure_default_doc(src_tool_id, tgt_tool_id, tool_name, src_docs)

    return result

def _load_target_tool_names(self) -> dict[str, str | None]:
    """Return ``{tool_id: tool_name}`` for the target org from one listing.

    API and transport failures are logged and yield an empty mapping, so
    callers fall back to tool ids in log lines instead of failing the phase.
    """
    try:
        tools = self.ctx.target.list_custom_tools()
    except PlatformAPIError as e:
        logger.warning(
            "files: list_custom_tools failed during name lookup (%s); "
            "log lines will fall back to tool ids",
            e,
        )
        return {}
    except (requests.ConnectionError, requests.Timeout) as e:
        logger.warning(
            "files: transport error during tool-name lookup (%s); "
            "log lines will fall back to tool ids",
            e,
        )
        return {}
    names: dict[str, str | None] = {}
    for tool in tools:
        # First row wins, matching a linear first-match search.
        names.setdefault(tool.get("tool_id"), tool.get("tool_name"))
    return names
except Exception as e: + logger.exception( + "files: download failed tool=%s file=%s: %s", + task.tool_name, + task.file_name, + e, + ) + with lock: + result.failed += 1 + report.failed_files.append( + { + "tool_id": task.tgt_tool_id, + "tool_name": task.tool_name, + "file_name": task.file_name, + "error": f"download: {e}", + } + ) + return + + mime = (payload or {}).get("mime_type") or "" + raw = self._decode_payload(payload, mime) + if raw is None: + logger.warning( + "files: unsupported mime tool=%s file=%s mime=%s", + task.tool_name, + task.file_name, + mime, + ) + with lock: + result.skipped += 1 + report.unsupported_files.append( + { + "tool_id": task.tgt_tool_id, + "tool_name": task.tool_name, + "file_name": task.file_name, + "mime_type": mime, + } + ) + return + + if len(raw) > self.ctx.options.max_file_size: + with lock: + result.skipped += 1 + report.oversize_files.append( + { + "tool_id": task.tgt_tool_id, + "tool_name": task.tool_name, + "file_name": task.file_name, + "size_bytes": len(raw), + "cap_bytes": self.ctx.options.max_file_size, + } + ) + logger.info( + "files: oversize tool=%s file=%s size=%d cap=%d", + task.tool_name, + task.file_name, + len(raw), + self.ctx.options.max_file_size, + ) + return + + try: + self._with_retry( + lambda: self.ctx.target.upload_prompt_file( + task.tgt_tool_id, task.file_name, raw, mime + ), + op=f"upload {task.tool_name}/{task.file_name}", + ) + except Exception as e: + logger.exception( + "files: upload failed tool=%s file=%s: %s", + task.tool_name, + task.file_name, + e, + ) + with lock: + result.failed += 1 + report.failed_files.append( + { + "tool_id": task.tgt_tool_id, + "tool_name": task.tool_name, + "file_name": task.file_name, + "error": f"upload: {e}", + } + ) + return + + with lock: + result.created += 1 + report.uploaded_files.append( + { + "tool_id": task.tgt_tool_id, + "tool_name": task.tool_name, + "file_name": task.file_name, + "size_bytes": len(raw), + "mime_type": mime, + } + ) + logger.info( + 
"files: uploaded tool=%s file=%s size=%d", + task.tool_name, + task.file_name, + len(raw), + ) + + def _emit_skip( + self, + src_docs: list[dict[str, Any]], + src_tool_id: str, + tgt_tool_id: str, + tool_name: str, + report: CloneReport, + result: PhaseResult, + ) -> None: + for doc in src_docs: + file_name = doc.get("document_name") + if not file_name: + continue + report.skipped_files.append( + { + "tool_id": tgt_tool_id, + "tool_name": tool_name, + "file_name": file_name, + "source_org_slug": self.ctx.source.endpoint.organization_id, + "source_tool_id": src_tool_id, + } + ) + result.skipped += 1 + logger.info( + "files: skip mode emitted %d filenames for tool=%s", + len(src_docs), + tool_name, + ) + + def _decode_payload( + self, payload: dict[str, Any] | None, mime: str + ) -> bytes | None: + if not payload: + return None + data_field = payload.get("data") + if data_field is None: + return None + if mime in _BASE64_MIMES: + if isinstance(data_field, bytes): + return base64.b64decode(data_field) + return base64.b64decode(data_field.encode()) + if mime in _TEXT_MIMES: + if isinstance(data_field, bytes): + return data_field + return data_field.encode("utf-8") + # Excel + unhandled types: BE returned a placeholder string, + # not real bytes. Round-trip would corrupt the file. + return None + + def _ensure_default_doc( + self, + src_tool_id: str, + tgt_tool_id: str, + tool_name: str, + src_docs: list[dict[str, Any]], + ) -> None: + """Set target ``CustomTool.output`` so the FE auto-selects a doc. + + Mirror source's chosen doc by filename when possible; fall back + to the first available target doc. Skip if target already has + ``output`` set — never override an operator's later choice on + re-runs. 
def _pick_default_doc_id(
    self,
    src_tool_id: str,
    src_docs: list[dict[str, Any]],
    tgt_docs: list[dict[str, Any]],
    tool_name: str,
) -> str | None:
    """Choose which target document id should become the tool's default.

    Looks up the source tool's ``output`` document, finds its filename in
    ``src_docs`` and returns the id of the first target document with the
    same filename. Whenever any step of that chain yields nothing (including
    a failed source fetch), the first target document's id is returned.
    """
    src_output = None
    try:
        src_output = self.ctx.source.get_custom_tool(src_tool_id).get("output")
    except Exception as err:
        logger.debug(
            "files: source CustomTool fetch failed for tool=%s (%s); "
            "falling back to first target doc",
            tool_name,
            err,
        )

    if src_output:
        # Only the FIRST source row with the chosen id is consulted.
        wanted_name = None
        for src_doc in src_docs:
            if src_doc.get("document_id") == src_output:
                wanted_name = src_doc.get("document_name")
                break

        if wanted_name:
            # Likewise only the first target row with that filename counts.
            for tgt_doc in tgt_docs:
                if tgt_doc.get("document_name") == wanted_name:
                    candidate = tgt_doc.get("document_id")
                    if candidate:
                        return candidate
                    break

    return tgt_docs[0].get("document_id")
def _with_retry(self, fn: Any, *, op: str) -> Any:
    """Invoke ``fn`` with a bounded number of attempts and exponential backoff.

    A ``PlatformAPIError`` is retried only when its ``status_code`` is in
    ``_RETRYABLE_STATUS``; ``requests.ConnectionError`` / ``requests.Timeout``
    are always retried. The final attempt re-raises whatever it caught, and
    any other exception type propagates immediately.

    Args:
        fn: Zero-argument callable performing the operation.
        op: Short label used in retry log lines.

    Returns:
        Whatever ``fn`` returns on its first successful call.
    """
    last_exc: Exception | None = None
    attempt = 0
    while attempt < _MAX_RETRIES:
        attempt += 1
        try:
            return fn()
        except PlatformAPIError as api_err:
            last_exc = api_err
            if (
                api_err.status_code not in _RETRYABLE_STATUS
                or attempt == _MAX_RETRIES
            ):
                raise
            pause = _RETRY_BACKOFF_BASE_SECONDS * (2 ** (attempt - 1))
            logger.warning(
                "files: retry %d/%d for %s after %d: sleeping %.1fs",
                attempt,
                _MAX_RETRIES,
                op,
                api_err.status_code,
                pause,
            )
        except (requests.ConnectionError, requests.Timeout) as net_err:
            last_exc = net_err
            if attempt == _MAX_RETRIES:
                raise
            pause = _RETRY_BACKOFF_BASE_SECONDS * (2 ** (attempt - 1))
            logger.warning(
                "files: retry %d/%d for %s after %s: sleeping %.1fs",
                attempt,
                _MAX_RETRIES,
                op,
                type(net_err).__name__,
                pause,
            )
        # Reached only after a retryable failure; both handlers set `pause`.
        time.sleep(pause)
    assert last_exc is not None
    raise last_exc
def _clone_one(
    self, src: dict[str, Any], result: PhaseResult, lock: threading.Lock
) -> None:
    """Adopt or create a single pipeline on the target and record its remap.

    Skips the pipeline when it has no workflow reference or that workflow
    has no target remap. An existing target pipeline with the same name is
    adopted (or raises ``NameConflictError`` when ``on_name_conflict`` is
    ``"abort"``); otherwise the full source row is fetched, remapped and
    POSTed. Counter updates on ``result`` happen under ``lock``.
    """
    name = src["pipeline_name"]
    src_id = src["id"]
    src_wf_id = src.get("workflow") or src.get("workflow_id")

    if not src_wf_id:
        logger.warning("source pipeline '%s' has no workflow FK — skipping", name)
        with lock:
            result.skipped += 1
        return

    with lock:
        tgt_wf_id = self.ctx.remap.resolve("workflow", src_wf_id)
    if not tgt_wf_id:
        logger.warning(
            "no workflow remap for pipeline '%s' (src workflow %s) — skipping",
            name,
            src_wf_id,
        )
        with lock:
            result.skipped += 1
        return

    try:
        matches = self.ctx.target.list_pipelines(name=name)
    except Exception as err:
        logger.exception("Failed to GET pipeline %s on target: %s", name, err)
        with lock:
            result.failed += 1
            result.errors.append(f"GET {name}: {err}")
        return

    # Nothing to adopt and nothing may be written: report and stop.
    if not matches and self.ctx.options.dry_run:
        with lock:
            result.skipped += 1
        logger.info("[dry-run] would create pipeline '%s' src=%s", name, src_id)
        return

    if matches:
        tgt = matches[0]
        if self.ctx.options.on_name_conflict == "abort":
            raise NameConflictError(
                f"pipeline '{name}' already exists in target as {tgt['id']}"
            )
        with lock:
            result.adopted += 1
        logger.info(
            "adopted pipeline '%s' src=%s -> tgt=%s", name, src_id, tgt["id"]
        )
    else:
        try:
            full_src = self.ctx.source.get_pipeline(src_id)
        except Exception as err:
            logger.exception("Failed to GET source pipeline %s: %s", name, err)
            with lock:
                result.failed += 1
                result.errors.append(f"GET src pipeline {name}: {err}")
            return

        payload = build_post_payload(
            remap_uuids(full_src, self.ctx.remap), self._writable
        )
        # The workflow reference is always forced to the remapped target id.
        payload["workflow"] = tgt_wf_id
        try:
            tgt = self.ctx.target.create_pipeline(payload)
        except Exception as err:
            logger.exception("Failed to create pipeline %s: %s", name, err)
            with lock:
                result.failed += 1
                result.errors.append(f"create {name}: {err}")
            return
        with lock:
            result.created += 1
        logger.info(
            "created pipeline '%s' src=%s -> tgt=%s", name, src_id, tgt["id"]
        )
        self._warn_if_extra_source_keys(src_id, name)

    with lock:
        self.ctx.remap.record("pipeline", src_id, tgt["id"])
def _warn_if_extra_source_keys(self, src_pipeline_id: str, name: str) -> None:
    """Warn when the source pipeline has more than one active API key.

    Best-effort check: a failure to list the source keys is itself logged
    as a warning and the check is skipped.
    """
    try:
        keys = self.ctx.source.list_pipeline_keys(src_pipeline_id)
    except Exception as err:
        # WARNING rather than DEBUG: the operator should know the check
        # for additional keys could not be performed.
        logger.warning(
            "Could not list source keys for pipeline %s "
            "(extra-key check skipped; re-verify in source UI): %s",
            name,
            err,
        )
        return

    active_count = sum(1 for key in keys if key.get("is_active"))
    if active_count <= 1:
        return
    logger.warning(
        "source pipeline '%s' had %d active API keys; "
        "target has only the auto-provisioned default — "
        "re-create the rest manually if your clients depend on them",
        name,
        active_count,
    )
class TagPhase(Phase):
    """Clone tags by name: adopt same-named target tags, create the rest."""

    name = "tag"

    def run(self, report: CloneReport) -> PhaseResult:
        """Fetch the target POST schema and source tags, then clone each tag.

        A failure of either preliminary call is counted as one failure on
        the phase result, which is returned immediately.
        """
        outcome = report.get_phase(self.name)

        try:
            self._writable = self.ctx.target.get_post_schema(TAG_PATH)
        except Exception as err:
            logger.exception("Failed to fetch target POST schema for tag: %s", err)
            outcome.failed += 1
            outcome.errors.append(f"OPTIONS tag: {err}")
            return outcome

        try:
            src_tags = self.ctx.source.list_tags()
        except Exception as err:
            logger.exception("Failed to list source tags: %s", err)
            outcome.failed += 1
            outcome.errors.append(f"list source tags: {err}")
            return outcome

        logger.info("Found %d tag(s) in source org", len(src_tags))

        def clone_tag(tag: dict[str, Any], lock: threading.Lock) -> None:
            return self._clone_one(tag, outcome, lock)

        self.parallel_map(src_tags, clone_tag)
        return outcome

    def _clone_one(
        self, src: dict[str, Any], result: PhaseResult, lock: threading.Lock
    ) -> None:
        """Adopt or create one tag on the target and record ``src -> tgt``.

        Raises ``NameConflictError`` when a same-named tag exists and
        ``on_name_conflict`` is ``"abort"``. In dry-run mode a missing tag
        is only reported, never created.
        """
        name = src["name"]
        src_id = src["id"]

        try:
            matches = self.ctx.target.list_tags(name=name)
        except Exception as err:
            logger.exception("Failed to GET tag %s on target: %s", name, err)
            with lock:
                result.failed += 1
                result.errors.append(f"GET {name}: {err}")
            return

        if not matches:
            if self.ctx.options.dry_run:
                with lock:
                    result.skipped += 1
                logger.info("[dry-run] would create tag '%s' src=%s", name, src_id)
                return

            payload = build_post_payload(src, self._writable)
            try:
                tgt = self.ctx.target.create_tag(payload)
            except Exception as err:
                logger.exception("Failed to create tag %s: %s", name, err)
                with lock:
                    result.failed += 1
                    result.errors.append(f"create {name}: {err}")
                return
            with lock:
                result.created += 1
            logger.info("created tag '%s' src=%s -> tgt=%s", name, src_id, tgt["id"])
        else:
            tgt = matches[0]
            if self.ctx.options.on_name_conflict == "abort":
                raise NameConflictError(
                    f"tag '{name}' already exists in target as {tgt['id']}"
                )
            with lock:
                result.adopted += 1
            logger.info("adopted tag '%s' src=%s -> tgt=%s", name, src_id, tgt["id"])

        with lock:
            self.ctx.remap.record("tag", src_id, tgt["id"])
def _clone_workflow_tools(
    self,
    src_wf_id: str,
    tgt_wf_id: str,
    result: PhaseResult,
    lock: threading.Lock,
) -> None:
    """Mirror the source workflow's first tool instance onto the target workflow.

    Steps: list source instances (only the first is migrated), resolve the
    tool id through the ``prompt_studio_registry`` remap, adopt or create the
    target instance, then PATCH its metadata unless the source metadata holds
    broken adapter references. Dry-run mode reports what would happen without
    creating or patching. Counters on ``result`` are updated under ``lock``.
    """
    try:
        src_instances = self.ctx.source.list_tool_instances(workflow_id=src_wf_id)
    except Exception as err:
        logger.exception(
            "Failed to list source tool_instances for wf %s: %s", src_wf_id, err
        )
        with lock:
            result.failed += 1
            result.errors.append(f"list src tool_instances {src_wf_id}: {err}")
        return

    if not src_instances:
        return
    if len(src_instances) > 1:
        logger.warning(
            "source workflow %s has %d tool_instances (expected ≤1) — migrating first only",
            src_wf_id,
            len(src_instances),
        )

    src_ti = src_instances[0]
    src_ti_id = src_ti["id"]
    src_tool_id = src_ti["tool_id"]

    with lock:
        tgt_tool_id = self.ctx.remap.resolve("prompt_studio_registry", src_tool_id)
    if not tgt_tool_id:
        logger.warning(
            "skipping tool_instance %s — no registry remap for tool_id %s "
            "(custom tool likely unpublished on source)",
            src_ti_id,
            src_tool_id,
        )
        with lock:
            result.skipped += 1
        return

    try:
        existing = self.ctx.target.list_tool_instances(workflow_id=tgt_wf_id)
    except Exception as err:
        logger.exception(
            "Failed to list target tool_instances for wf %s: %s", tgt_wf_id, err
        )
        with lock:
            result.failed += 1
            result.errors.append(f"list tgt tool_instances {tgt_wf_id}: {err}")
        return

    dry_run = self.ctx.options.dry_run

    # Dry run with an existing row: remember the pairing, change nothing.
    if existing and dry_run:
        tgt_ti = existing[0]
        with lock:
            result.skipped += 1
            self.ctx.remap.record("tool_instance", src_ti_id, tgt_ti["id"])
        logger.info(
            "[dry-run] would re-PATCH metadata on adopted tool_instance "
            "src=%s -> tgt=%s (workflow %s)",
            src_ti_id,
            tgt_ti["id"],
            tgt_wf_id,
        )
        return

    # Dry run without an existing row: only report the would-be create.
    if dry_run:
        with lock:
            result.skipped += 1
        logger.info(
            "[dry-run] would create tool_instance for tgt workflow %s "
            "(src tool_instance %s)",
            tgt_wf_id,
            src_ti_id,
        )
        return

    if existing:
        tgt_ti = existing[0]
        with lock:
            result.adopted += 1
        logger.info(
            "adopted tool_instance src=%s -> tgt=%s (workflow %s)",
            src_ti_id,
            tgt_ti["id"],
            tgt_wf_id,
        )
    else:
        try:
            tgt_ti = self.ctx.target.create_tool_instance(
                {"workflow_id": tgt_wf_id, "tool_id": tgt_tool_id}
            )
        except Exception as err:
            logger.exception(
                "Failed to create tool_instance for wf %s: %s", tgt_wf_id, err
            )
            with lock:
                result.failed += 1
                result.errors.append(f"create tool_instance {tgt_wf_id}: {err}")
            return
        with lock:
            result.created += 1
        logger.info(
            "created tool_instance src=%s -> tgt=%s (workflow %s)",
            src_ti_id,
            tgt_ti["id"],
            tgt_wf_id,
        )

    src_metadata = src_ti.get("metadata") or {}
    broken = _broken_adapter_keys(src_metadata)
    if not broken:
        # PATCH replaces the whole metadata dict, so the target's own
        # identity fields are stamped back in after stripping the source's.
        patch_metadata = dict(_strip_source_identity(src_metadata))
        patch_metadata["prompt_registry_id"] = tgt_tool_id
        patch_metadata["tool_instance_id"] = tgt_ti["id"]
        try:
            self.ctx.target.update_tool_instance_metadata(
                tgt_ti["id"], patch_metadata
            )
        except Exception as err:
            logger.exception(
                "Failed to PATCH tool_instance %s metadata: %s", tgt_ti["id"], err
            )
            with lock:
                result.failed += 1
                result.errors.append(f"patch metadata {tgt_ti['id']}: {err}")
            return
    else:
        logger.warning(
            "skipping metadata PATCH for tool_instance src=%s tgt=%s — "
            "source metadata carries broken adapter refs %s; "
            "row exists with backend defaults, re-bind in UI",
            src_ti_id,
            tgt_ti["id"],
            broken,
        )
        with lock:
            result.skipped += 1
            result.errors.append(
                f"stale adapter refs on src tool_instance {src_ti_id}: {broken}"
            )

    with lock:
        self.ctx.remap.record("tool_instance", src_ti_id, tgt_ti["id"])
class WorkflowPhase(Phase):
    """Clone workflows by name: adopt same-named target rows, create the rest."""

    name = "workflow"

    def run(self, report: CloneReport) -> PhaseResult:
        """Fetch the target POST schema and source workflows, then clone each.

        A failure of either preliminary call is counted as one failure on
        the phase result, which is returned immediately.
        """
        outcome = report.get_phase(self.name)

        try:
            self._writable = self.ctx.target.get_post_schema(WORKFLOW_PATH)
        except Exception as err:
            logger.exception(
                "Failed to fetch target POST schema for workflow: %s", err
            )
            outcome.failed += 1
            outcome.errors.append(f"OPTIONS workflow: {err}")
            return outcome

        try:
            src_workflows = self.ctx.source.list_workflows()
        except Exception as err:
            logger.exception("Failed to list source workflows: %s", err)
            outcome.failed += 1
            outcome.errors.append(f"list source workflows: {err}")
            return outcome

        logger.info("Found %d workflow(s) in source org", len(src_workflows))

        def clone_workflow(workflow: dict[str, Any], lock: threading.Lock) -> None:
            return self._clone_one(workflow, outcome, lock)

        self.parallel_map(src_workflows, clone_workflow)
        return outcome

    def _clone_one(
        self, src: dict[str, Any], result: PhaseResult, lock: threading.Lock
    ) -> None:
        """Adopt or create one workflow on the target and record ``src -> tgt``.

        New workflows are POSTed after ``remap_uuids`` rewrites ids in the
        source row. Raises ``NameConflictError`` when a same-named workflow
        exists and ``on_name_conflict`` is ``"abort"``. In dry-run mode a
        missing workflow is only reported, never created.
        """
        name = src["workflow_name"]
        src_id = src["id"]

        try:
            matches = self.ctx.target.list_workflows(name=name)
        except Exception as err:
            logger.exception("Failed to GET workflow %s on target: %s", name, err)
            with lock:
                result.failed += 1
                result.errors.append(f"GET {name}: {err}")
            return

        if not matches:
            if self.ctx.options.dry_run:
                with lock:
                    result.skipped += 1
                logger.info(
                    "[dry-run] would create workflow '%s' src=%s", name, src_id
                )
                return

            payload = build_post_payload(
                remap_uuids(src, self.ctx.remap), self._writable
            )
            try:
                tgt = self.ctx.target.create_workflow(payload)
            except Exception as err:
                logger.exception("Failed to create workflow %s: %s", name, err)
                with lock:
                    result.failed += 1
                    result.errors.append(f"create {name}: {err}")
                return
            with lock:
                result.created += 1
            logger.info(
                "created workflow '%s' src=%s -> tgt=%s", name, src_id, tgt["id"]
            )
        else:
            tgt = matches[0]
            if self.ctx.options.on_name_conflict == "abort":
                raise NameConflictError(
                    f"workflow '{name}' already exists in target as {tgt['id']}"
                )
            with lock:
                result.adopted += 1
            logger.info(
                "adopted workflow '%s' src=%s -> tgt=%s", name, src_id, tgt["id"]
            )

        with lock:
            self.ctx.remap.record("workflow", src_id, tgt["id"])
def _extract_connector_id(endpoint: dict[str, Any]) -> str | None:
    """Return the connector id referenced by an endpoint row, if any.

    ``connector_instance`` may be a nested dict (its ``id`` is returned) or a
    plain string (returned as-is); any other value yields ``None``.
    """
    ref = endpoint.get("connector_instance")
    if isinstance(ref, str):
        return ref
    return ref.get("id") if isinstance(ref, dict) else None
def _patch_endpoint(
    self,
    src_ep: dict[str, Any],
    tgt_ep: dict[str, Any],
    result: PhaseResult,
    lock: threading.Lock,
) -> None:
    """PATCH one target endpoint with the source endpoint's settings.

    The payload carries the remapped ``configuration``, the remapped
    connector id (``None`` when the source has no connector) and the source
    ``connection_type`` when present. An endpoint whose source connector has
    no target remap is skipped rather than patched with an unset connector.
    Dry-run mode only logs the intended PATCH.
    """
    src_ep_id = src_ep["id"]
    tgt_ep_id = tgt_ep["id"]
    etype = src_ep["endpoint_type"]

    if self.ctx.options.dry_run:
        with lock:
            result.skipped += 1
        logger.info(
            "[dry-run] would PATCH %s endpoint src=%s -> tgt=%s",
            etype,
            src_ep_id,
            tgt_ep_id,
        )
        return

    tgt_conn_id: str | None = None
    src_conn_id = _extract_connector_id(src_ep)
    if src_conn_id:
        with lock:
            tgt_conn_id = self.ctx.remap.resolve("connector", src_conn_id)
        if not tgt_conn_id:
            logger.warning(
                "skipping %s endpoint src=%s tgt=%s — source connector %s "
                "has no target remap; would silently unset connector",
                etype,
                src_ep_id,
                tgt_ep_id,
                src_conn_id,
            )
            with lock:
                result.skipped += 1
                result.errors.append(
                    f"unmapped connector on {etype} endpoint {src_ep_id}: "
                    f"src_connector={src_conn_id}"
                )
            return

    remapped_config = remap_uuids(
        src_ep.get("configuration") or {}, self.ctx.remap
    )
    payload: dict[str, Any] = {
        "configuration": remapped_config,
        "connector_instance_id": tgt_conn_id,
    }
    connection_type = src_ep.get("connection_type")
    if connection_type is not None:
        payload["connection_type"] = connection_type

    try:
        self.ctx.target.update_workflow_endpoint(tgt_ep_id, payload)
    except Exception as err:
        logger.exception(
            "Failed to PATCH %s endpoint tgt=%s: %s", etype, tgt_ep_id, err
        )
        with lock:
            result.failed += 1
            result.errors.append(f"patch {etype} {tgt_ep_id}: {err}")
        return

    # A successful PATCH is counted under `created` for this phase.
    with lock:
        result.created += 1
        self.ctx.remap.record("workflow_endpoint", src_ep_id, tgt_ep_id)
    logger.info(
        "patched %s endpoint src=%s -> tgt=%s (connector %s)",
        etype,
        src_ep_id,
        tgt_ep_id,
        tgt_conn_id,
    )
def get_phase(self, name: str) -> PhaseResult:
    """Return the ``PhaseResult`` called ``name``, creating it on first use.

    The first entry in ``self.phases`` whose ``name`` matches is returned.
    When there is none, a fresh zeroed ``PhaseResult`` is appended to
    ``self.phases`` and returned.
    """
    existing = next((ph for ph in self.phases if ph.name == name), None)
    if existing is not None:
        return existing
    fresh = PhaseResult(name=name)
    self.phases.append(fresh)
    return fresh
+ console = Console( + file=buf, force_terminal=True, color_system="truecolor", width=100 + ) + self._render_endpoints(console.print) + table = Table(title="Clone Report", header_style="bold cyan") + table.add_column("Phase", style="bold", justify="left") + for col in ("Created", "Adopted", "Skipped", "Failed", "Time"): + table.add_column(col, justify="right") + + totals = {"created": 0, "adopted": 0, "skipped": 0, "failed": 0} + for p in self.phases: + phase_style = "red" if p.failed else ("yellow" if p.skipped else "green") + table.add_row( + f"[{phase_style}]{p.name}[/{phase_style}]", + self._fmt_count(p.created, "green"), + self._fmt_count(p.adopted, "green"), + self._fmt_count(p.skipped, "yellow"), + self._fmt_count(p.failed, "red"), + self._fmt_duration(p.duration_s), + ) + for k in totals: + totals[k] += getattr(p, k) + + table.add_section() + table.add_row( + "[bold]TOTAL[/bold]", + self._fmt_count(totals["created"], "green", bold=True), + self._fmt_count(totals["adopted"], "green", bold=True), + self._fmt_count(totals["skipped"], "yellow", bold=True), + self._fmt_count(totals["failed"], "red", bold=True), + self._fmt_duration(self.total_duration_s, bold=True), + ) + console.print(table) + if self.skipped_phases: + console.print( + f"[dim]Skipped phases:[/dim] {', '.join(self.skipped_phases)}" + ) + self._render_files_sections(console) + self._render_remap_summary(console_print=console.print) + if self.aborted: + console.print(f"[bold red]ABORTED:[/bold red] {self.abort_reason}") + elif totals["failed"]: + console.print( + f"[bold red]Completed with {totals['failed']} failure(s)[/bold red] — " + "see WARNING/ERROR log lines above for details" + ) + else: + console.print("[bold green]Completed successfully[/bold green]") + return buf.getvalue() + + @staticmethod + def _fmt_count(value: int, color: str, bold: bool = False) -> str: + """Dim a zero to keep the eye on non-zero cells; colour anything > 0.""" + if value == 0: + return "[dim]0[/dim]" + style = 
f"bold {color}" if bold else color + return f"[{style}]{value}[/{style}]" + + @staticmethod + def _fmt_duration(seconds: float, bold: bool = False) -> str: + if seconds <= 0: + return "[dim]—[/dim]" + if seconds < 60: + text = f"{seconds:.1f}s" + else: + mins, secs = divmod(seconds, 60) + text = f"{int(mins)}m{secs:.0f}s" + return f"[bold]{text}[/bold]" if bold else text + + @staticmethod + def _fmt_duration_plain(seconds: float) -> str: + if seconds <= 0: + return "—" + if seconds < 60: + return f"{seconds:.1f}s" + mins, secs = divmod(seconds, 60) + return f"{int(mins)}m{secs:.0f}s" + + def _render_plain(self) -> str: + lines = ["Clone Report", "=" * 60] + self._render_endpoints(lines.append) + header = ( + f"{'Phase':<24}{'Created':>10}{'Adopted':>10}" + f"{'Skipped':>10}{'Failed':>10}{'Time':>10}" + ) + lines.append(header) + for p in self.phases: + lines.append( + f"{p.name:<24}{p.created:>10}{p.adopted:>10}" + f"{p.skipped:>10}{p.failed:>10}{self._fmt_duration_plain(p.duration_s):>10}" + ) + lines.append( + f"{'TOTAL':<64}{self._fmt_duration_plain(self.total_duration_s):>10}" + ) + if self.skipped_phases: + lines.append(f"Skipped phases: {', '.join(self.skipped_phases)}") + lines.extend(self._files_sections_plain()) + self._render_remap_summary(console_print=lines.append) + if self.aborted: + lines.append(f"ABORTED: {self.abort_reason}") + return "\n".join(lines) + + def as_dict(self) -> dict[str, Any]: + return { + "source": ( + { + "base_url": self.source.base_url, + "organization_id": self.source.organization_id, + } + if self.source + else None + ), + "target": ( + { + "base_url": self.target.base_url, + "organization_id": self.target.organization_id, + } + if self.target + else None + ), + "phases": [ + { + "name": p.name, + "created": p.created, + "adopted": p.adopted, + "skipped": p.skipped, + "failed": p.failed, + "errors": list(p.errors), + "duration_s": p.duration_s, + } + for p in self.phases + ], + "skipped_phases": list(self.skipped_phases), + 
"remap_snapshot": self.remap_snapshot, + "aborted": self.aborted, + "abort_reason": self.abort_reason, + "total_duration_s": self.total_duration_s, + "uploaded_files": list(self.uploaded_files), + "skipped_files": list(self.skipped_files), + "oversize_files": list(self.oversize_files), + "unsupported_files": list(self.unsupported_files), + "failed_files": list(self.failed_files), + } + + def _render_endpoints(self, console_print: Any) -> None: + if not self.source and not self.target: + return + src = self._fmt_endpoint(self.source) + tgt = self._fmt_endpoint(self.target) + console_print(f"Source: {src}") + console_print(f"Target: {tgt}") + + @staticmethod + def _fmt_endpoint(ep: Endpoint | None) -> str: + if ep is None: + return "?" + return f"{ep.organization_id} @ {ep.base_url}" + + def _render_remap_summary(self, console_print: Any) -> None: + """Summarise the remap snapshot. Full map is large and noisy, so + we only print per-entity counts here; the full mapping is emitted + at DEBUG and remains in ``as_dict()`` for programmatic consumers. 
+ """ + if not self.remap_snapshot: + return + counts = ", ".join( + f"{entity}={len(mapping)}" + for entity, mapping in self.remap_snapshot.items() + if mapping + ) + if counts: + console_print(f"Remap entries: {counts}") + if logger.isEnabledFor(logging.DEBUG): + for entity, mapping in self.remap_snapshot.items(): + for src, tgt in mapping.items(): + logger.debug("remap %s %s -> %s", entity, src, tgt) + + def _render_files_sections(self, console: Any) -> None: + if self.uploaded_files: + console.print(f"[green]Files uploaded:[/green] {len(self.uploaded_files)}") + for header, rows in ( + ("Oversize files (manual upload required)", self.oversize_files), + ("Unsupported mime files (manual upload required)", self.unsupported_files), + ("Skipped files (operator action required)", self.skipped_files), + ("Failed files", self.failed_files), + ): + if not rows: + continue + console.print(f"[yellow]{header}:[/yellow]") + for row in rows: + console.print(f" - {self._describe_file_row(row)}") + + def _files_sections_plain(self) -> list[str]: + lines: list[str] = [] + if self.uploaded_files: + lines.append(f"Files uploaded: {len(self.uploaded_files)}") + for header, rows in ( + ("Oversize files (manual upload required)", self.oversize_files), + ("Unsupported mime files (manual upload required)", self.unsupported_files), + ("Skipped files (operator action required)", self.skipped_files), + ("Failed files", self.failed_files), + ): + if not rows: + continue + lines.append(f"{header}:") + for row in rows: + lines.append(f" - {self._describe_file_row(row)}") + return lines + + @staticmethod + def _describe_file_row(row: dict[str, Any]) -> str: + tool = row.get("tool_name") or row.get("tool_id") or "?" 
UUID_RE = re.compile(
    r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
    re.IGNORECASE,
)


def remap_uuids(obj: Any, remap: RemapTable) -> Any:
    """Return a copy of a JSON-shaped value with known source UUIDs swapped.

    Dicts and lists are rebuilt recursively (dict keys are left as they
    are). A string is replaced only when it matches ``UUID_RE`` and
    ``remap.resolve_any`` returns a truthy mapping for it; every other
    value, including UUID-looking strings with no mapping, is returned
    unchanged.
    """
    if isinstance(obj, str):
        if UUID_RE.match(obj):
            return remap.resolve_any(obj) or obj
        return obj
    if isinstance(obj, list):
        return [remap_uuids(item, remap) for item in obj]
    if isinstance(obj, dict):
        return {key: remap_uuids(value, remap) for key, value in obj.items()}
    return obj
class FakeClient:
    """In-memory double for ``PlatformClient`` covering the adapter calls.

    ``adapters`` is the fake backing store and ``posts`` records every
    payload accepted by ``create_adapter`` in call order.
    """

    # Mirrors DRF OPTIONS actions.POST writable fields for adapter.
    POST_SCHEMA = frozenset(
        {"adapter_id", "adapter_name", "adapter_type", "adapter_metadata", "description"}
    )

    def __init__(self, adapters: list[dict] | None = None):
        # Copy the seed list so create_adapter never mutates the caller's list.
        self.adapters: list[dict] = list(adapters or [])
        self.posts: list[dict] = []
        self._next_id = 1

    def get_post_schema(self, entity_path):
        """Return the fixed writable-field set, whatever the path."""
        return self.POST_SCHEMA

    def list_adapters(self, *, name=None, adapter_type=None):
        """List stored adapters, optionally filtered by name and/or type."""
        matches = [
            row
            for row in self.adapters
            if (name is None or row["adapter_name"] == name)
            and (adapter_type is None or row["adapter_type"] == adapter_type)
        ]
        # Mimic AdapterListSerializer — strip adapter_metadata from list output.
        return [
            {key: value for key, value in row.items() if key != "adapter_metadata"}
            for row in matches
        ]

    def get_adapter(self, adapter_pk):
        """Return the stored adapter dict with this id, or raise ``KeyError``."""
        found = next((row for row in self.adapters if row["id"] == adapter_pk), None)
        if found is None:
            raise KeyError(adapter_pk)
        return found

    def create_adapter(self, payload):
        """Store a copy of ``payload`` under a generated id and return it."""
        created = dict(payload)
        created["id"] = f"tgt-{self._next_id:08d}-0000-0000-0000-000000000000"
        self._next_id += 1
        self.adapters.append(created)
        self.posts.append(created)
        return created
def test_dry_run_makes_no_posts():
    """With ``dry_run`` set the adapter is counted as skipped and never POSTed."""
    source = FakeClient([_src_adapter("src-a", "OpenAI Prod")])
    target = FakeClient()
    phase = AdapterPhase(_ctx(source, target, dry_run=True))

    outcome = phase.run(CloneReport())

    assert outcome.skipped == 1
    assert outcome.created == 0
    assert target.posts == []
class FakeClient:
    """In-memory double for ``PlatformClient`` covering API-deployment calls.

    ``deployments`` is the fake backing store, ``posts`` records created
    payloads in call order, and ``keys_by_deployment`` lets a test seed the
    API keys reported for a deployment id.
    """

    def __init__(self, deployments: list[dict] | None = None):
        self.deployments: list[dict] = list(deployments or [])
        self.posts: list[dict] = []
        self.keys_by_deployment: dict[str, list[dict]] = {}
        self._next = 1

    def get_post_schema(self, entity_path: str) -> frozenset[str]:
        """Return the module-level writable-field set, whatever the path."""
        return API_DEPLOYMENT_POST_SCHEMA

    def list_api_deployments(self, *, api_name: str | None = None):
        """Return a new list of deployments, optionally filtered by ``api_name``."""
        if api_name is None:
            return list(self.deployments)
        return [dep for dep in self.deployments if dep["api_name"] == api_name]

    def get_api_deployment(self, deployment_id: str) -> dict:
        """Return a copy of the deployment with this id, or raise ``KeyError``."""
        found = next(
            (dep for dep in self.deployments if dep["id"] == deployment_id), None
        )
        if found is None:
            raise KeyError(deployment_id)
        return dict(found)

    def create_api_deployment(self, payload: dict) -> dict:
        """Store a copy of ``payload`` with a generated id and api key."""
        serial = self._next
        self._next += 1
        created = dict(payload)
        created["id"] = f"tgt-dep-{serial:04d}"
        created["api_key"] = f"key-{serial:04d}"
        self.deployments.append(created)
        self.posts.append(created)
        return created

    def list_api_deployment_keys(self, deployment_id: str) -> list[dict]:
        """Return a copy of the seeded key list for a deployment (empty if none)."""
        return list(self.keys_by_deployment.get(deployment_id, []))
def _ctx(source, target, *, remap=None, **opt_overrides):
    """Build a ``CloneContext`` for a test.

    Args:
        source: Fake client standing in for the source org.
        target: Fake client standing in for the target org.
        remap: Pre-populated remap table to use; a fresh ``RemapTable`` is
            created only when this is ``None``.
        **opt_overrides: Forwarded to ``CloneOptions``.
    """
    return CloneContext(
        source=source,
        target=target,
        options=CloneOptions(**opt_overrides),
        # ``is not None`` rather than ``remap or ...``: a caller-supplied
        # table must be kept even if it is empty and evaluates as falsy,
        # otherwise the test would assert against a different table than
        # the one the phase wrote to.
        remap=remap if remap is not None else RemapTable(),
    )
def test_preserves_false_and_zero_values():
    """Booleans set to False and numeric 0 are legitimate field values.

    Regression guard: ``build_post_payload`` must keep ``False``, ``0`` and
    ``0.0`` in the payload. A truthiness-style emptiness filter would drop
    them along with ``None`` and ``""``. (A plain
    ``value not in (None, "")`` check does keep them, since neither
    ``False`` nor ``0`` compares equal to ``None`` or ``""``.)
    """
    src = {
        "is_active": False,
        "retry_count": 0,
        "rate_limit": 0.0,
        "name": "demo",
    }
    writable = frozenset({"is_active", "retry_count", "rate_limit", "name"})

    payload = build_post_payload(src, writable)

    # Every falsy-but-meaningful value must come through unchanged.
    assert payload == {
        "is_active": False,
        "retry_count": 0,
        "rate_limit": 0.0,
        "name": "demo",
    }
def test_parse_size_accepts_kb_mb_gb_units():
    """Unit suffixes are binary multiples; decimal magnitudes are accepted."""
    kib = 1024
    expectations = {
        "25MB": 25 * kib * kib,
        "1.5GB": int(1.5 * kib * kib * kib),
        "512K": 512 * kib,
    }
    for raw, expected_bytes in expectations.items():
        assert _parse_size(raw) == expected_bytes
captured: dict = {} + + def fake_clone(source, target, options=None): + captured["options"] = options + return CloneReport( + source=Endpoint( + base_url=source.base_url, organization_id=source.organization_id + ), + target=Endpoint( + base_url=target.base_url, organization_id=target.organization_id + ), + ) + + monkeypatch.setattr("unstract.clone.cli.run_clone", fake_clone) + + result = CliRunner().invoke( + cli, + [ + "clone", + "--source-url", + "http://src", + "--source-org", + "src", + "--source-key", + "sk", + "--target-url", + "http://tgt", + "--target-org", + "tgt", + "--target-key", + "tk", + ], + ) + + assert result.exit_code == 0, result.output + opts: CloneOptions = captured["options"] + assert opts.max_file_size == DEFAULT_MAX_FILE_SIZE diff --git a/tests/clone/test_client.py b/tests/clone/test_client.py new file mode 100644 index 0000000..9fa3dca --- /dev/null +++ b/tests/clone/test_client.py @@ -0,0 +1,145 @@ +"""Tests for ``PlatformClient`` HTTP layer. + +Coverage: +- URL composition honours base_url, api_path_prefix, organization_id. +- Bearer auth header present on every request. +- Non-2xx response raises ``PlatformAPIError`` with status_code + body. +- 204 / empty body returns ``None`` instead of raising on .json(). +- ``get_post_schema`` parses DRF ``actions.POST`` and caches per path. +- ``close()`` shuts the underlying session; context manager works. 
+""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from unstract.clone.client import PlatformClient +from unstract.clone.context import OrgEndpoint +from unstract.clone.exceptions import PlatformAPIError + + +def _endpoint() -> OrgEndpoint: + return OrgEndpoint( + base_url="https://api.example.com", + organization_id="org_abc", + platform_key="plat-key-xyz", + ) + + +def _fake_response(status: int, payload=None, text: str = "") -> MagicMock: + resp = MagicMock() + resp.status_code = status + resp.text = text + resp.content = b"" if payload is None and not text else b"x" + resp.json.return_value = payload + return resp + + +def _client_with_mock( + payload=None, status: int = 200, text: str = "" +) -> tuple[PlatformClient, MagicMock]: + client = PlatformClient(_endpoint()) + mock_request = MagicMock(return_value=_fake_response(status, payload, text)) + client._session.request = mock_request + return client, mock_request + + +def test_url_composition_includes_org_and_api_prefix(): + client, mock_request = _client_with_mock(payload=[]) + client.list_adapters() + call = mock_request.call_args + assert call.args[0] == "GET" + assert call.args[1] == "https://api.example.com/api/v1/unstract/org_abc/adapter/" + + +def test_bearer_token_sent_on_session(): + client, _ = _client_with_mock(payload=[]) + assert client._session.headers["Authorization"] == "Bearer plat-key-xyz" + assert client._session.headers["Accept"] == "application/json" + + +def test_non_2xx_raises_platform_api_error_with_status_and_body(): + client, _ = _client_with_mock(status=404, text="not found") + with pytest.raises(PlatformAPIError) as exc_info: + client.list_adapters() + err = exc_info.value + assert err.status_code == 404 + assert "not found" in err.body + + +def test_500_with_long_body_truncated_to_2000_chars(): + big = "x" * 5000 + client, _ = _client_with_mock(status=500, text=big) + with pytest.raises(PlatformAPIError) as exc_info: + 
def test_204_no_content_returns_none():
    """A 204 with an empty body makes ``_request`` return ``None``."""
    empty_response = MagicMock()
    empty_response.content = b""
    empty_response.status_code = 204
    client = PlatformClient(_endpoint())
    client._session.request = MagicMock(return_value=empty_response)

    outcome = client._request("DELETE", "tag/abc/")

    assert outcome is None
class FakeClient:
    """In-memory double for ``PlatformClient`` covering the connector calls.

    ``connectors`` is the fake backing store and ``posts`` records every
    payload accepted by ``create_connector`` in call order.
    """

    POST_SCHEMA = frozenset(
        {
            "connector_id",
            "connector_name",
            "connector_metadata",
            "connector_version",
            "connector_mode",
            "connector_type",
            "shared_to_org",
        }
    )

    def __init__(self, connectors: list[dict] | None = None):
        self.connectors: list[dict] = list(connectors or [])
        self.posts: list[dict] = []
        self._next_id = 1

    def get_post_schema(self, entity_path):
        """Return the fixed writable-field set, whatever the path."""
        return self.POST_SCHEMA

    def list_connectors(self, *, name=None, connector_type=None):
        """Return a new list of connectors, optionally filtered.

        ``name`` must match ``connector_name`` exactly; ``connector_type``
        is compared against a possibly-absent ``connector_type`` key.
        """
        return [
            row
            for row in self.connectors
            if (name is None or row["connector_name"] == name)
            and (connector_type is None or row.get("connector_type") == connector_type)
        ]

    def get_connector(self, connector_pk):
        """Return the stored connector dict with this id, or raise ``KeyError``."""
        found = next((row for row in self.connectors if row["id"] == connector_pk), None)
        if found is None:
            raise KeyError(connector_pk)
        return found

    def create_connector(self, payload):
        """Store a copy of ``payload`` under a generated id and return it."""
        created = dict(payload)
        created["id"] = f"tgt-{self._next_id:08d}-0000-0000-0000-000000000000"
        self._next_id += 1
        self.connectors.append(created)
        self.posts.append(created)
        return created
+def _src(id_, name, catalog_id="postgres|abc", ctype="INPUT"): + return { + "id": id_, + "connector_id": catalog_id, + "connector_name": name, + "connector_type": ctype, + "connector_version": "1.0", + "connector_metadata": {"host": "db.example.com", "password": "secret"}, + "shared_to_org": False, + } + + +def _ctx(source, target, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=RemapTable(), + ) + + +def test_happy_path_creates_all_and_records_remap(): + src = FakeClient([_src("src-a", "Prod PG"), _src("src-b", "Stg S3", "s3|xyz")]) + tgt = FakeClient() + ctx = _ctx(src, tgt) + report = CloneReport() + + result = ConnectorPhase(ctx).run(report) + + assert result.created == 2 + assert result.adopted == 0 + assert result.skipped == 0 + assert len(tgt.posts) == 2 + assert ctx.remap.resolve("connector", "src-a") == tgt.posts[0]["id"] + assert ctx.remap.resolve("connector", "src-b") == tgt.posts[1]["id"] + + +def test_redacted_metadata_connector_skipped(): + """Source returning empty metadata (redacted by backend) is unmigratable — + skipped with no POST and no remap entry.""" + redacted = _src("src-ucs", "User Storage") + redacted["connector_metadata"] = {} # backend redaction signal + src = FakeClient([redacted]) + tgt = FakeClient() + ctx = _ctx(src, tgt) + report = CloneReport() + + result = ConnectorPhase(ctx).run(report) + + assert result.skipped == 1 + assert result.created == 0 + assert tgt.posts == [] + assert ctx.remap.resolve("connector", "src-ucs") is None + + +def test_idempotency_zero_creates_on_rerun(): + src = FakeClient([_src("src-a", "Prod PG")]) + tgt = FakeClient( + [ + { + "id": "preexisting", + "connector_id": "postgres|abc", + "connector_name": "Prod PG", + "connector_type": "INPUT", + "connector_metadata": {}, + } + ] + ) + ctx = _ctx(src, tgt, on_name_conflict="adopt") + report = CloneReport() + + result = ConnectorPhase(ctx).run(report) + + assert result.created == 0 
+ assert result.adopted == 1 + assert tgt.posts == [] + assert ctx.remap.resolve("connector", "src-a") == "preexisting" + + +def test_dry_run_makes_no_posts(): + src = FakeClient([_src("src-a", "Prod PG")]) + tgt = FakeClient() + ctx = _ctx(src, tgt, dry_run=True) + report = CloneReport() + + result = ConnectorPhase(ctx).run(report) + + assert result.skipped == 1 + assert result.created == 0 + assert tgt.posts == [] + + +def test_abort_on_name_conflict_raises(): + src = FakeClient([_src("src-a", "Prod PG")]) + tgt = FakeClient( + [ + { + "id": "preexisting", + "connector_id": "postgres|abc", + "connector_name": "Prod PG", + "connector_type": "INPUT", + "connector_metadata": {}, + } + ] + ) + ctx = _ctx(src, tgt, on_name_conflict="abort") + report = CloneReport() + + with pytest.raises(NameConflictError): + ConnectorPhase(ctx).run(report) diff --git a/tests/clone/test_custom_tool_phase.py b/tests/clone/test_custom_tool_phase.py new file mode 100644 index 0000000..dac4a13 --- /dev/null +++ b/tests/clone/test_custom_tool_phase.py @@ -0,0 +1,349 @@ +"""Tests for ``CustomToolPhase`` — project-transfer + sync-prompts based. + +Coverage: +- fresh path: ``export_project`` on source → ``import_project`` on + target with adapter ids resolved by looking up each source-profile + adapter NAME against the target via ``list_adapters(name=...)``. +- adopt path: existing target tool with matching name → + ``sync_prompts`` overwrites prompts; no profile/adapter writes. +- registry remap recorded after ``export_custom_tool``. +- dry-run: no writes on either side. +- abort on name conflict when option is set. +- missing target adapter fails the tool cleanly. 
+""" + +from __future__ import annotations + +import pytest + +from unstract.clone.context import ( + CloneContext, + CloneOptions, + RemapTable, +) +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.custom_tool import CustomToolPhase +from unstract.clone.report import CloneReport + +ADAPTER_NAMES = { + "llm": "gpt4", + "embedding_model": "ada-embed", + "vector_store": "pgvector", + "x2text": "llmw", +} +TGT_ADAPTER_IDS = { + "gpt4": "a1111111-1111-1111-1111-111111111111", + "ada-embed": "a2222222-2222-2222-2222-222222222222", + "pgvector": "a3333333-3333-3333-3333-333333333333", + "llmw": "a4444444-4444-4444-4444-444444444444", +} +SRC_REG = "55555555-5555-5555-5555-555555555555" + + +class FakeClient: + """In-memory stand-in for ``PlatformClient`` covering project-transfer.""" + + def __init__(self) -> None: + self.tools: dict[str, dict] = {} + self.profiles_by_tool: dict[str, list[dict]] = {} + self.export_blobs: dict[str, dict] = {} + self.registries_by_tool: dict[str, dict] = {} + self.adapters_by_name: dict[str, dict] = {} + # Call recorders. 
+ self.import_calls: list[tuple[dict, dict | None]] = [] + self.sync_calls: list[tuple[str, dict, bool]] = [] + self.export_tool_calls: list[str] = [] + self._next = 1 + + def _mint(self, prefix: str) -> str: + s = f"tgt-{prefix}-{self._next:04d}" + self._next += 1 + return s + + # --- reads --- + def list_custom_tools(self) -> list[dict]: + return [ + {"tool_id": tid, "tool_name": t["tool_name"]} + for tid, t in self.tools.items() + ] + + def list_profiles(self, tool_id: str) -> list[dict]: + return list(self.profiles_by_tool.get(tool_id, [])) + + def export_project(self, tool_id: str) -> dict: + return self.export_blobs[tool_id] + + def list_adapters( + self, + *, + name: str | None = None, + adapter_type: str | None = None, + ) -> list[dict]: + if name is None: + return list(self.adapters_by_name.values()) + ad = self.adapters_by_name.get(name) + return [ad] if ad else [] + + def list_registries(self, *, custom_tool: str | None = None) -> list[dict]: + if custom_tool is None: + return list(self.registries_by_tool.values()) + reg = self.registries_by_tool.get(custom_tool) + return [reg] if reg else [] + + # --- writes --- + def import_project( + self, export_data: dict, adapter_ids: dict | None = None + ) -> dict: + self.import_calls.append((export_data, adapter_ids)) + tool_id = self._mint("tool") + tool_name = export_data["tool_metadata"]["tool_name"] + self.tools[tool_id] = {"tool_name": tool_name} + return { + "tool_id": tool_id, + "message": f"Project imported successfully as '{tool_name}'", + "needs_adapter_config": adapter_ids is None, + } + + def sync_prompts( + self, tool_id: str, export_data: dict, *, create_copy: bool = False + ) -> dict: + self.sync_calls.append((tool_id, export_data, create_copy)) + return { + "prompts_created": len(export_data.get("prompts", [])), + "prompts_deleted": 0, + "tool_settings_updated": True, + } + + def export_custom_tool(self, tool_id: str, *, force: bool = True) -> None: + self.export_tool_calls.append(tool_id) + 
self.registries_by_tool.setdefault( + tool_id, + {"prompt_registry_id": self._mint("registry"), "custom_tool": tool_id}, + ) + + +def _ctx(source, target, *, remap=None, **opt_overrides) -> CloneContext: + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=remap or RemapTable(), + ) + + +def _seed_target_adapters(target: FakeClient) -> None: + """ProfileManagerSerializer surfaces adapter NAMES — target must + expose name → id lookups for the phase to resolve them. + """ + for name, adapter_id in TGT_ADAPTER_IDS.items(): + target.adapters_by_name[name] = {"id": adapter_id, "adapter_name": name} + + +def _src_default_profile(*, nested: bool = False) -> dict: + """Mirror the live ProfileManager serializer: adapter FKs render as + flat NAME strings. ``nested=True`` covers the alternate dict shape + in case backend behavior changes. + """ + if nested: + return { + "profile_id": "src-profile-1", + "profile_name": "Default", + "is_default": True, + "llm": {"adapter_name": ADAPTER_NAMES["llm"]}, + "embedding_model": {"adapter_name": ADAPTER_NAMES["embedding_model"]}, + "vector_store": {"adapter_name": ADAPTER_NAMES["vector_store"]}, + "x2text": {"adapter_name": ADAPTER_NAMES["x2text"]}, + } + return { + "profile_id": "src-profile-1", + "profile_name": "Default", + "is_default": True, + "llm": ADAPTER_NAMES["llm"], + "embedding_model": ADAPTER_NAMES["embedding_model"], + "vector_store": ADAPTER_NAMES["vector_store"], + "x2text": ADAPTER_NAMES["x2text"], + } + + +def _src_export_blob(tool_name: str) -> dict: + return { + "tool_metadata": { + "tool_name": tool_name, + "description": "x", + "author": "a", + "icon": None, + }, + "tool_settings": {"preamble": "p", "postamble": "q"}, + "default_profile_settings": { + "chunk_size": 1024, + "chunk_overlap": 128, + "retrieval_strategy": "simple", + "similarity_top_k": 3, + "section": "default", + "profile_name": "Default", + }, + "prompts": [ + { + "prompt_key": "field_a", + 
"prompt": "What is field_a?", + "sequence_number": 1, + } + ], + "export_metadata": {"exported_at": "2026-05-24T00:00:00Z"}, + } + + +def _preload_source_tool( + client: FakeClient, tool_id: str, tool_name: str, *, nested_profile: bool = False +) -> None: + client.tools[tool_id] = {"tool_name": tool_name} + client.profiles_by_tool[tool_id] = [_src_default_profile(nested=nested_profile)] + client.export_blobs[tool_id] = _src_export_blob(tool_name) + client.registries_by_tool[tool_id] = { + "prompt_registry_id": SRC_REG, + "custom_tool": tool_id, + } + + +def test_fresh_imports_with_name_resolved_adapter_ids_and_records_registry(): + src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "Invoice Extractor") + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt) + + result = CustomToolPhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert result.failed == 0 + # Exactly one import_project call with the right export blob + name-resolved adapter ids. + assert len(tgt.import_calls) == 1 + blob, adapter_ids = tgt.import_calls[0] + assert blob["tool_metadata"]["tool_name"] == "Invoice Extractor" + assert adapter_ids == { + "llm_adapter_id": TGT_ADAPTER_IDS["gpt4"], + "vector_db_adapter_id": TGT_ADAPTER_IDS["pgvector"], + "embedding_adapter_id": TGT_ADAPTER_IDS["ada-embed"], + "x2text_adapter_id": TGT_ADAPTER_IDS["llmw"], + } + # No sync_prompts on fresh path. + assert tgt.sync_calls == [] + # Registry republish fired exactly once. + assert len(tgt.export_tool_calls) == 1 + tgt_tool_id = tgt.export_tool_calls[0] + + # Remap records populated for downstream phases. 
+ assert ctx.remap.resolve("custom_tool", "src-tool-x") == tgt_tool_id + tgt_reg_id = tgt.registries_by_tool[tgt_tool_id]["prompt_registry_id"] + assert ctx.remap.resolve("prompt_studio_registry", SRC_REG) == tgt_reg_id + + +def test_nested_adapter_dict_also_resolves(): + src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "T", nested_profile=True) + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt) + + CustomToolPhase(ctx).run(CloneReport()) + + _, adapter_ids = tgt.import_calls[0] + assert adapter_ids["llm_adapter_id"] == TGT_ADAPTER_IDS["gpt4"] + + +def test_adopt_path_calls_sync_prompts_and_skips_import(): + src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "Invoice Extractor") + # Target already has the tool with the same name. + tgt.tools["tgt-existing"] = {"tool_name": "Invoice Extractor"} + + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt) + + result = CustomToolPhase(ctx).run(CloneReport()) + + assert result.adopted == 1 + assert result.created == 0 + # sync_prompts ran against the pre-existing target tool, not a new one. + assert len(tgt.sync_calls) == 1 + tool_id, blob, create_copy = tgt.sync_calls[0] + assert tool_id == "tgt-existing" + assert blob["tool_metadata"]["tool_name"] == "Invoice Extractor" + assert create_copy is False + # Import path never fired on adopt. + assert tgt.import_calls == [] + # Registry still republished against the adopted tool. 
+ assert tgt.export_tool_calls == ["tgt-existing"] + assert ctx.remap.resolve("custom_tool", "src-tool-x") == "tgt-existing" + + +def test_abort_on_name_conflict_raises(): + src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "Conflict") + tgt.tools["tgt-existing"] = {"tool_name": "Conflict"} + + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt, on_name_conflict="abort") + + with pytest.raises(NameConflictError): + CustomToolPhase(ctx).run(CloneReport()) + + assert tgt.sync_calls == [] + assert tgt.import_calls == [] + + +def test_dry_run_makes_no_writes(): + src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "T") + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt, dry_run=True) + + result = CustomToolPhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert tgt.import_calls == [] + assert tgt.sync_calls == [] + assert tgt.export_tool_calls == [] + + +def test_dry_run_on_adopt_path_does_not_republish_registry(): + # Adopt path used to return tgt_tool_id even on dry-run, falling + # through to export_custom_tool (a real POST to the target). + src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "Invoice Extractor") + tgt.tools["tgt-existing"] = {"tool_name": "Invoice Extractor"} + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt, dry_run=True) + + result = CustomToolPhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert tgt.sync_calls == [] + assert tgt.import_calls == [] + # Critical regression: registry republish must NOT fire on dry-run. + assert tgt.export_tool_calls == [] + # Remap still recorded so downstream dry-run output stays coherent. + assert ctx.remap.resolve("custom_tool", "src-tool-x") == "tgt-existing" + + +def test_missing_target_adapter_fails_tool_cleanly(): + src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "T") + # Only seed 3 of 4 adapters → x2text lookup misses on target. 
+ for name in ("gpt4", "ada-embed", "pgvector"): + tgt.adapters_by_name[name] = {"id": TGT_ADAPTER_IDS[name], "adapter_name": name} + ctx = _ctx(src, tgt) + + result = CustomToolPhase(ctx).run(CloneReport()) + + assert result.failed == 1 + assert tgt.import_calls == [] + # Registry republish should NOT fire when the tool fails. + assert tgt.export_tool_calls == [] + # No custom_tool remap recorded. + assert ctx.remap.resolve("custom_tool", "src-tool-x") is None diff --git a/tests/clone/test_files_phase.py b/tests/clone/test_files_phase.py new file mode 100644 index 0000000..50f739e --- /dev/null +++ b/tests/clone/test_files_phase.py @@ -0,0 +1,525 @@ +"""Tests for ``FilesPhase``. + +Coverage: +- happy path: PDF + text/csv files uploaded with base64 + utf-8 decoding. +- target-side idempotency: filename already present → skip, no upload. +- oversize file → ``oversize_files`` entry, sibling files continue. +- unsupported mime (Excel placeholder) → ``unsupported_files`` entry. +- skip strategy → no uploads, source filenames listed in ``skipped_files``. +- dry-run → no uploads even for missing files. +- transient 503 → retried, eventual success. +- no custom_tool remap → no-op. +- listing failure on source aborts only that tool, others continue. 
+""" + +from __future__ import annotations + +import base64 +from typing import Any + +import pytest + +from unstract.clone.context import ( + CloneContext, + CloneOptions, + OrgEndpoint, + RemapTable, +) +from unstract.clone.exceptions import PlatformAPIError +from unstract.clone.phases.files import FilesPhase +from unstract.clone.report import CloneReport + +SRC_ENDPOINT = OrgEndpoint( + base_url="http://src", organization_id="src-org", platform_key="src-key" +) +TGT_ENDPOINT = OrgEndpoint( + base_url="http://tgt", organization_id="tgt-org", platform_key="tgt-key" +) + + +class FakeClient: + def __init__( + self, + *, + endpoint: OrgEndpoint, + documents: dict[str, list[dict]] | None = None, + file_payloads: dict[tuple[str, str], dict] | None = None, + tools: list[dict] | None = None, + ): + self.endpoint = endpoint + # tool_id -> list of {document_name, document_id, tool} + self._documents: dict[str, list[dict]] = { + k: list(v) for k, v in (documents or {}).items() + } + # (tool_id, file_name) -> {"data": ..., "mime_type": ...} + self._file_payloads: dict[tuple[str, str], dict] = dict(file_payloads or {}) + self._tools = list(tools or []) + self.uploaded: list[dict[str, Any]] = [] + self.list_calls: list[str] = [] + self.download_calls: list[tuple[str, str]] = [] + # Configurable fault injection. + self.download_errors: dict[tuple[str, str], list[Exception]] = {} + self.upload_errors: dict[tuple[str, str], list[Exception]] = {} + self.list_errors: dict[str, Exception] = {} + self._next_id = 1 + + def list_prompt_documents(self, tool_id: str) -> list[dict]: + self.list_calls.append(tool_id) + if tool_id in self.list_errors: + raise self.list_errors[tool_id] + return [dict(d) for d in self._documents.get(tool_id, [])] + + def download_prompt_file(self, tool_id: str, document_id: str) -> dict: + # Tests key payloads + error queues by (tool_id, file_name) for + # readability; resolve the filename from the documents list. 
+ file_name = next( + ( + d["document_name"] + for d in self._documents.get(tool_id, []) + if d.get("document_id") == document_id + ), + document_id, + ) + self.download_calls.append((tool_id, file_name)) + queue = self.download_errors.get((tool_id, file_name)) + if queue: + raise queue.pop(0) + return dict(self._file_payloads[(tool_id, file_name)]) + + def upload_prompt_file( + self, tool_id: str, file_name: str, data: bytes, mime_type: str + ) -> dict: + queue = self.upload_errors.get((tool_id, file_name)) + if queue: + raise queue.pop(0) + doc_id = f"doc-{self._next_id:04d}" + self._next_id += 1 + self.uploaded.append( + { + "tool_id": tool_id, + "file_name": file_name, + "data": data, + "mime_type": mime_type, + } + ) + self._documents.setdefault(tool_id, []).append( + {"document_id": doc_id, "document_name": file_name, "tool": tool_id} + ) + return {"document_id": doc_id} + + def list_custom_tools(self) -> list[dict]: + return list(self._tools) + + def get_custom_tool(self, tool_id: str) -> dict: + return dict(next((t for t in self._tools if t.get("tool_id") == tool_id), {})) + + def update_custom_tool(self, tool_id: str, body: dict) -> dict: + for t in self._tools: + if t.get("tool_id") == tool_id: + t.update(body) + return dict(t) + return {} + + +def _ctx( + src: FakeClient, tgt: FakeClient, *, remap: RemapTable | None = None, **opts +) -> CloneContext: + remap = remap or RemapTable() + return CloneContext( + source=src, + target=tgt, + options=CloneOptions(**opts), + remap=remap, + ) + + +def _doc(name: str) -> dict: + return {"document_id": f"src-{name}", "document_name": name, "tool": "ignored"} + + +def _pdf_payload(raw: bytes) -> dict: + return {"data": base64.b64encode(raw).decode(), "mime_type": "application/pdf"} + + +def _text_payload(text: str, mime: str = "text/plain") -> dict: + return {"data": text, "mime_type": mime} + + +def test_happy_path_uploads_pdf_and_text(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": 
[_doc("invoice.pdf"), _doc("notes.txt")]}, + file_payloads={ + ("src-1", "invoice.pdf"): _pdf_payload(b"%PDF-FAKE"), + ("src-1", "notes.txt"): _text_payload("hello world"), + }, + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 2 + assert result.failed == 0 + assert {u["file_name"] for u in tgt.uploaded} == {"invoice.pdf", "notes.txt"} + pdf_upload = next(u for u in tgt.uploaded if u["file_name"] == "invoice.pdf") + assert pdf_upload["data"] == b"%PDF-FAKE" + assert pdf_upload["mime_type"] == "application/pdf" + txt_upload = next(u for u in tgt.uploaded if u["file_name"] == "notes.txt") + assert txt_upload["data"] == b"hello world" + assert len(report.uploaded_files) == 2 + assert all(u["tool_name"] == "demo" for u in report.uploaded_files) + + +def test_target_filename_present_is_skipped_no_download(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("invoice.pdf")]}, + file_payloads={("src-1", "invoice.pdf"): _pdf_payload(b"BYTES")}, + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, + documents={"tgt-1": [_doc("invoice.pdf")]}, + tools=[{"tool_id": "tgt-1", "tool_name": "demo"}], + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.skipped == 1 + assert result.created == 0 + assert tgt.uploaded == [] + assert src.download_calls == [] # pre-check guards the download + + +def test_oversize_file_is_recorded_and_siblings_continue(): + big = b"X" * 50 + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("big.pdf"), _doc("small.txt")]}, + file_payloads={ + ("src-1", "big.pdf"): _pdf_payload(big), + ("src-1", "small.txt"): 
_text_payload("ok"), + }, + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap, max_file_size=10) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 1 + # Oversize must bump skipped so the operator sees it surfaced in the + # phase counters, not only in the report's list. + assert result.skipped == 1 + assert result.failed == 0 + assert {u["file_name"] for u in tgt.uploaded} == {"small.txt"} + assert len(report.oversize_files) == 1 + over = report.oversize_files[0] + assert over["file_name"] == "big.pdf" + assert over["size_bytes"] == 50 + assert over["cap_bytes"] == 10 + + +def test_unsupported_mime_is_recorded_not_uploaded(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("sheet.xlsx")]}, + file_payloads={ + ("src-1", "sheet.xlsx"): { + "data": "Preview not available for Excel files. ...", + "mime_type": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + } + }, + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 0 + # Unsupported mimes must bump skipped so the run doesn't report green + # while leaving files unmoved. + assert result.skipped == 1 + assert result.failed == 0 + assert tgt.uploaded == [] + assert len(report.unsupported_files) == 1 + entry = report.unsupported_files[0] + assert entry["file_name"] == "sheet.xlsx" + assert entry["mime_type"].startswith("application/vnd.openxmlformats") + + +def test_malformed_source_dm_row_bumps_skipped_with_error(): + # Renamed-field or partial-serializer response: row lacks + # document_name/document_id. 
Must surface, not silently disappear. + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [{"tool": "src-1"}, _doc("ok.pdf")]}, + file_payloads={("src-1", "ok.pdf"): _pdf_payload(b"BYTES")}, + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 1 # the well-formed sibling still uploads. + assert result.skipped == 1 # the malformed row. + assert any("malformed source DM row" in e for e in result.errors) + + +def test_skip_strategy_emits_skipped_files_no_traffic(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf"), _doc("b.pdf")]}, + ) + tgt = FakeClient(endpoint=TGT_ENDPOINT) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap, file_strategy="skip") + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.skipped == 2 + assert tgt.uploaded == [] + assert src.download_calls == [] + names = {row["file_name"] for row in report.skipped_files} + assert names == {"a.pdf", "b.pdf"} + assert all(row["source_org_slug"] == "src-org" for row in report.skipped_files) + assert all(row["source_tool_id"] == "src-1" for row in report.skipped_files) + + +def test_dry_run_makes_no_writes_even_for_missing_files(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf")]}, + file_payloads={("src-1", "a.pdf"): _pdf_payload(b"X")}, + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap, dry_run=True) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.skipped == 1 + assert result.created == 0 + assert 
tgt.uploaded == [] + assert src.download_calls == [] + + +def test_transient_503_is_retried_then_succeeds(monkeypatch): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf")]}, + file_payloads={("src-1", "a.pdf"): _pdf_payload(b"OK")}, + ) + src.download_errors[("src-1", "a.pdf")] = [ + PlatformAPIError("flaky", status_code=503, body="") + ] + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + # Strip the backoff sleep so the test stays fast. + monkeypatch.setattr("unstract.clone.phases.files.time.sleep", lambda *_: None) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 1 + assert tgt.uploaded[0]["data"] == b"OK" + + +def test_no_custom_tool_remap_is_noop(): + src = FakeClient(endpoint=SRC_ENDPOINT) + tgt = FakeClient(endpoint=TGT_ENDPOINT) + ctx = _ctx(src, tgt) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 0 + assert result.skipped == 0 + assert src.list_calls == [] + + +def test_source_list_failure_isolates_to_that_tool(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-2": [_doc("ok.pdf")]}, + file_payloads={("src-2", "ok.pdf"): _pdf_payload(b"OK")}, + ) + src.list_errors["src-1"] = RuntimeError("source down for this tool") + tgt = FakeClient( + endpoint=TGT_ENDPOINT, + tools=[ + {"tool_id": "tgt-1", "tool_name": "broken"}, + {"tool_id": "tgt-2", "tool_name": "healthy"}, + ], + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + remap.record("custom_tool", "src-2", "tgt-2") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.failed == 1 + assert result.created == 1 + assert {u["file_name"] for u in tgt.uploaded} == {"ok.pdf"} + + +def test_upload_failure_records_failed_files_entry(): 
+ src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf")]}, + file_payloads={("src-1", "a.pdf"): _pdf_payload(b"X")}, + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + tgt.upload_errors[("tgt-1", "a.pdf")] = [ + PlatformAPIError("bad", status_code=400, body="bad") + ] + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.failed == 1 + assert result.created == 0 + assert len(report.failed_files) == 1 + entry = report.failed_files[0] + assert entry["file_name"] == "a.pdf" + assert "upload" in entry["error"] + + +@pytest.mark.parametrize( + "mime,raw", + [ + ("text/csv", "name,age\nalice,30"), + ("text/plain", "plain old text"), + ], +) +def test_text_mimes_round_trip_as_utf8(mime, raw): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("data")]}, + file_payloads={("src-1", "data"): _text_payload(raw, mime=mime)}, + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 1 + upload = tgt.uploaded[0] + assert upload["data"] == raw.encode("utf-8") + assert upload["mime_type"] == mime + + +def test_default_doc_mirrors_source_selection_by_filename(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf"), _doc("b.pdf")]}, + file_payloads={ + ("src-1", "a.pdf"): _pdf_payload(b"A"), + ("src-1", "b.pdf"): _pdf_payload(b"B"), + }, + # Source's selected doc is b.pdf (document_id="src-b.pdf"). 
+ tools=[{"tool_id": "src-1", "tool_name": "demo", "output": "src-b.pdf"}], + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + FilesPhase(ctx).run(CloneReport()) + + # Target's CustomTool.output now points at b.pdf's new target doc id. + tgt_tool = next(t for t in tgt._tools if t["tool_id"] == "tgt-1") + output_id = tgt_tool["output"] + b_upload = next(d for d in tgt._documents["tgt-1"] if d["document_name"] == "b.pdf") + assert output_id == b_upload["document_id"] + + +def test_default_doc_falls_back_to_first_when_source_has_none(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf")]}, + file_payloads={("src-1", "a.pdf"): _pdf_payload(b"A")}, + # Source has no output set. + tools=[{"tool_id": "src-1", "tool_name": "demo"}], + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + FilesPhase(ctx).run(CloneReport()) + + tgt_tool = next(t for t in tgt._tools if t["tool_id"] == "tgt-1") + a_upload = next(d for d in tgt._documents["tgt-1"] if d["document_name"] == "a.pdf") + assert tgt_tool["output"] == a_upload["document_id"] + + +def test_default_doc_preserves_existing_target_choice(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf")]}, + file_payloads={("src-1", "a.pdf"): _pdf_payload(b"A")}, + tools=[{"tool_id": "src-1", "tool_name": "demo", "output": "src-a.pdf"}], + ) + # Operator already picked a doc on target — re-run must not clobber. 
+ tgt = FakeClient( + endpoint=TGT_ENDPOINT, + tools=[{"tool_id": "tgt-1", "tool_name": "demo", "output": "operator-pick"}], + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + FilesPhase(ctx).run(CloneReport()) + + tgt_tool = next(t for t in tgt._tools if t["tool_id"] == "tgt-1") + assert tgt_tool["output"] == "operator-pick" diff --git a/tests/clone/test_orchestrator.py b/tests/clone/test_orchestrator.py new file mode 100644 index 0000000..b7a2626 --- /dev/null +++ b/tests/clone/test_orchestrator.py @@ -0,0 +1,159 @@ +"""End-to-end tests for the ``clone()`` orchestrator. + +Coverage: +- Phase ordering matches ``PHASES`` declaration. +- ``include`` / ``exclude`` route phases through ``skipped_phases``. +- ``CloneError`` raised by a phase aborts the run; subsequent phases skipped. +- Both ``PlatformClient`` instances are closed even when a phase aborts. +- ``RemapTable`` snapshot lands on the report. +""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest + +from unstract.clone import orchestrator +from unstract.clone.context import CloneOptions, OrgEndpoint +from unstract.clone.exceptions import CloneError +from unstract.clone.phases.base import Phase +from unstract.clone.report import CloneReport, PhaseResult + + +class _RecordingPhase(Phase): + """Per-test phase factory; records invocation order on a shared list.""" + + invocations: list[str] = [] + name = "" + + def run(self, report: CloneReport) -> PhaseResult: + _RecordingPhase.invocations.append(self.name) + result = report.get_phase(self.name) + result.created += 1 + # Drop a remap entry so we can prove the snapshot lands on the report. 
+ self.ctx.remap.record(self.name, f"src-{self.name}", f"tgt-{self.name}") + return result + + +def _make_phase(phase_name: str) -> type[Phase]: + return type( + f"FakePhase_{phase_name}", + (_RecordingPhase,), + {"name": phase_name}, + ) + + +@pytest.fixture(autouse=True) +def _reset_invocations(): + _RecordingPhase.invocations = [] + yield + _RecordingPhase.invocations = [] + + +@pytest.fixture +def fake_phases(): + """Replace PHASES with a small deterministic set for the test run.""" + fake = [ + ("adapter", _make_phase("adapter")), + ("connector", _make_phase("connector")), + ("workflow", _make_phase("workflow")), + ] + with patch.object(orchestrator, "PHASES", fake): + yield fake + + +def _src() -> OrgEndpoint: + return OrgEndpoint( + base_url="https://src.example.com", + organization_id="src_org", + platform_key="src-key", + ) + + +def _tgt() -> OrgEndpoint: + return OrgEndpoint( + base_url="https://tgt.example.com", + organization_id="tgt_org", + platform_key="tgt-key", + ) + + +def test_phases_run_in_declared_order(fake_phases): + with patch.object(orchestrator.PlatformClient, "close") as mock_close: + report = orchestrator.clone(_src(), _tgt()) + assert _RecordingPhase.invocations == ["adapter", "connector", "workflow"] + assert [p.name for p in report.phases] == ["adapter", "connector", "workflow"] + # Both clients must close (source + target) regardless of outcome. 
+ assert mock_close.call_count == 2 + + +def test_include_filter_only_runs_listed_phases(fake_phases): + opts = CloneOptions(include=("connector",)) + with patch.object(orchestrator.PlatformClient, "close"): + report = orchestrator.clone(_src(), _tgt(), opts) + assert _RecordingPhase.invocations == ["connector"] + assert set(report.skipped_phases) == {"adapter", "workflow"} + + +def test_exclude_filter_skips_listed_phases(fake_phases): + opts = CloneOptions(exclude=("workflow",)) + with patch.object(orchestrator.PlatformClient, "close"): + report = orchestrator.clone(_src(), _tgt(), opts) + assert _RecordingPhase.invocations == ["adapter", "connector"] + assert report.skipped_phases == ["workflow"] + + +def test_clone_error_aborts_and_skips_subsequent_phases(): + class AbortingPhase(Phase): + name = "connector" + + def run(self, report: CloneReport) -> PhaseResult: + raise CloneError("name collision in 'connector'") + + fake = [ + ("adapter", _make_phase("adapter")), + ("connector", AbortingPhase), + ("workflow", _make_phase("workflow")), + ] + with ( + patch.object(orchestrator, "PHASES", fake), + patch.object(orchestrator.PlatformClient, "close") as mock_close, + ): + report = orchestrator.clone(_src(), _tgt()) + + assert _RecordingPhase.invocations == ["adapter"] + assert report.aborted is True + assert "name collision" in report.abort_reason + # Clients still close on abort. 
+ assert mock_close.call_count == 2 + + +def test_unrelated_exception_propagates_but_still_closes_clients(): + class CrashingPhase(Phase): + name = "connector" + + def run(self, report: CloneReport) -> PhaseResult: + raise RuntimeError("boom") + + fake = [ + ("adapter", _make_phase("adapter")), + ("connector", CrashingPhase), + ] + with ( + patch.object(orchestrator, "PHASES", fake), + patch.object(orchestrator.PlatformClient, "close") as mock_close, + ): + with pytest.raises(RuntimeError, match="boom"): + orchestrator.clone(_src(), _tgt()) + assert mock_close.call_count == 2 + + +def test_remap_snapshot_populated_on_report(fake_phases): + with patch.object(orchestrator.PlatformClient, "close"): + report = orchestrator.clone(_src(), _tgt()) + assert report.remap_snapshot == { + "adapter": {"src-adapter": "tgt-adapter"}, + "connector": {"src-connector": "tgt-connector"}, + "workflow": {"src-workflow": "tgt-workflow"}, + } diff --git a/tests/clone/test_phase_concurrency.py b/tests/clone/test_phase_concurrency.py new file mode 100644 index 0000000..0179f1b --- /dev/null +++ b/tests/clone/test_phase_concurrency.py @@ -0,0 +1,291 @@ +"""Thread-safety checks for ``Phase.parallel_map``. + +Coverage: +- Many-item fan-out produces exact counts + remap entries with no loss. +- Sequential path (``concurrency=1``) skips the thread pool entirely + while preserving identical behaviour. +- ``CloneError`` raised inside a worker propagates out of ``parallel_map`` + so the orchestrator's abort handling engages. +- A non-``CloneError`` exception in a worker is counted as failed, not raised. + +We use a fake client that holds a lock around its own mutable state and +injects a small sleep per HTTP call to force real interleaving between +workers, then assert the phase's lock-guarded code keeps counters and +the remap table consistent. 
+""" + +from __future__ import annotations + +import threading +import time + +import pytest + +from unstract.clone.context import CloneContext, CloneOptions, RemapTable +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.adapter import AdapterPhase +from unstract.clone.phases.tag import TagPhase +from unstract.clone.report import CloneReport + + +class _ThreadSafeAdapterClient: + """Adapter FakeClient with a lock around mutable state + per-call sleep + so workers actually interleave under ThreadPoolExecutor. + """ + + POST_SCHEMA = frozenset( + { + "adapter_id", + "adapter_name", + "adapter_type", + "adapter_metadata", + "description", + } + ) + + def __init__(self, adapters=None, sleep_seconds: float = 0.005): + self._adapters: list[dict] = list(adapters or []) + self.posts: list[dict] = [] + self._next_id = 1 + self._lock = threading.Lock() + self._sleep = sleep_seconds + + def get_post_schema(self, entity_path): + return self.POST_SCHEMA + + def list_adapters(self, *, name=None, adapter_type=None): + time.sleep(self._sleep) + with self._lock: + snap = list(self._adapters) + result = snap + if name is not None: + result = [a for a in result if a["adapter_name"] == name] + if adapter_type is not None: + result = [a for a in result if a["adapter_type"] == adapter_type] + return [{k: v for k, v in a.items() if k != "adapter_metadata"} for a in result] + + def get_adapter(self, adapter_pk): + time.sleep(self._sleep) + with self._lock: + for a in self._adapters: + if a["id"] == adapter_pk: + return dict(a) + raise KeyError(adapter_pk) + + def create_adapter(self, payload): + time.sleep(self._sleep) + with self._lock: + new = dict(payload) + new["id"] = f"tgt-{self._next_id:08d}-0000-0000-0000-000000000000" + self._next_id += 1 + self._adapters.append(new) + self.posts.append(new) + return new + + +def _src_adapter(id_, name, atype="LLM"): + return { + "id": id_, + "adapter_id": "openai-llm-v2", + "adapter_name": name, + 
"adapter_type": atype, + "adapter_metadata": {"api_key": "sk-secret", "model": "gpt-4"}, + "description": f"{name} desc", + } + + +def _ctx(source, target, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=RemapTable(), + ) + + +def test_parallel_map_preserves_counts_with_many_items(): + items = 50 + src = _ThreadSafeAdapterClient( + [_src_adapter(f"src-{i:03d}", f"adapter-{i:03d}") for i in range(items)] + ) + tgt = _ThreadSafeAdapterClient() + ctx = _ctx(src, tgt, concurrency=8) + report = CloneReport() + + result = AdapterPhase(ctx).run(report) + + assert result.created == items + assert result.adopted == 0 + assert result.skipped == 0 + assert result.failed == 0 + assert len(tgt.posts) == items + remap = ctx.remap.snapshot().get("adapter", {}) + assert len(remap) == items + # Every source id should be mapped to a fresh target id. + assert set(remap.keys()) == {f"src-{i:03d}" for i in range(items)} + assert len(set(remap.values())) == items + + +def test_concurrency_one_runs_sequentially_with_no_executor(monkeypatch): + """With concurrency=1 we should never hit ThreadPoolExecutor.""" + sentinel = {"executor_used": False} + + import unstract.clone.phases.base as base_mod + + original = base_mod.ThreadPoolExecutor + + class _Forbidden: + def __init__(self, *a, **kw): + sentinel["executor_used"] = True + raise AssertionError("ThreadPoolExecutor must not be used at concurrency=1") + + monkeypatch.setattr(base_mod, "ThreadPoolExecutor", _Forbidden) + src = _ThreadSafeAdapterClient( + [_src_adapter(f"src-{i}", f"a-{i}") for i in range(5)] + ) + tgt = _ThreadSafeAdapterClient() + ctx = _ctx(src, tgt, concurrency=1) + report = CloneReport() + + result = AdapterPhase(ctx).run(report) + + assert result.created == 5 + assert sentinel["executor_used"] is False + # restore for any other tests in same module (monkeypatch undoes on teardown). 
+ base_mod.ThreadPoolExecutor = original # noqa: F841 + + +class _AbortingAdapterClient(_ThreadSafeAdapterClient): + """As parent, but ``list_adapters`` claims the named adapter already + exists on target — used to trigger NameConflictError when the phase + is run with ``on_name_conflict='abort'``.""" + + def list_adapters(self, *, name=None, adapter_type=None): + time.sleep(self._sleep) + return [ + { + "id": "tgt-existing-0001", + "adapter_name": name or "x", + "adapter_type": adapter_type or "LLM", + } + ] + + +def test_clone_error_in_worker_propagates_under_concurrency(): + src = _ThreadSafeAdapterClient( + [_src_adapter(f"src-{i}", f"clash-{i}") for i in range(10)] + ) + tgt = _AbortingAdapterClient() + ctx = _ctx(src, tgt, concurrency=4, on_name_conflict="abort") + report = CloneReport() + + with pytest.raises(NameConflictError): + AdapterPhase(ctx).run(report) + + +class _UnexpectedAdapterClient(_ThreadSafeAdapterClient): + """One of the GETs blows up with a non-Clone RuntimeError.""" + + def __init__(self, *a, fail_on_name: str, **kw): + super().__init__(*a, **kw) + self._fail_on_name = fail_on_name + + def get_adapter(self, adapter_pk): + snap = super().get_adapter(adapter_pk) + if snap["adapter_name"] == self._fail_on_name: + raise RuntimeError("transport boom") + return snap + + +def test_non_clone_exception_recorded_as_failed_not_raised(): + """Workers convert non-Clone errors into ``result.failed`` counts; + they don't escape the phase. (CloneError is the abort signal — + arbitrary exceptions are per-item failures.)""" + src = _UnexpectedAdapterClient( + adapters=[_src_adapter(f"src-{i}", f"adapter-{i}") for i in range(10)], + fail_on_name="adapter-3", + ) + tgt = _ThreadSafeAdapterClient() + ctx = _ctx(src, tgt, concurrency=4) + report = CloneReport() + + result = AdapterPhase(ctx).run(report) + + assert result.failed == 1 + # The other 9 still created successfully. 
+ assert result.created == 9 + assert len(tgt.posts) == 9 + + +class _TagClient: + """Minimal tag fake with thread-safe state + per-call sleep.""" + + POST_SCHEMA = frozenset({"name", "description"}) + + def __init__(self, tags=None, sleep_seconds: float = 0.005): + self._tags: list[dict] = list(tags or []) + self.posts: list[dict] = [] + self._next_id = 1 + self._lock = threading.Lock() + self._sleep = sleep_seconds + + def get_post_schema(self, entity_path): + return self.POST_SCHEMA + + def list_tags(self, *, name=None): + time.sleep(self._sleep) + with self._lock: + snap = list(self._tags) + if name is not None: + snap = [t for t in snap if t["name"] == name] + return snap + + def create_tag(self, payload): + time.sleep(self._sleep) + with self._lock: + new = dict(payload) + new["id"] = f"tag-tgt-{self._next_id:04d}" + self._next_id += 1 + self._tags.append(new) + self.posts.append(new) + return new + + +def test_tag_phase_parallel_remap_table_consistent(): + """Distinct phase exercising the same parallel_map path — ensures the + helper isn't accidentally adapter-specific. + """ + src = _TagClient( + [{"id": f"tag-src-{i}", "name": f"tag-{i:03d}"} for i in range(30)] + ) + tgt = _TagClient() + ctx = _ctx(src, tgt, concurrency=8) + report = CloneReport() + + result = TagPhase(ctx).run(report) + + assert result.created == 30 + assert result.failed == 0 + remap = ctx.remap.snapshot().get("tag", {}) + assert len(remap) == 30 + # remap value uniqueness — no two source tags mapped to the same target id. 
+ assert len(set(remap.values())) == 30 + + +def test_parallel_map_empty_input_no_executor(monkeypatch): + """No items → no thread pool, no work.""" + import unstract.clone.phases.base as base_mod + + class _Forbidden: + def __init__(self, *a, **kw): + raise AssertionError("Should not create pool for empty input") + + monkeypatch.setattr(base_mod, "ThreadPoolExecutor", _Forbidden) + src = _ThreadSafeAdapterClient([]) + tgt = _ThreadSafeAdapterClient() + ctx = _ctx(src, tgt, concurrency=8) + report = CloneReport() + + result = AdapterPhase(ctx).run(report) + assert result.created == 0 + assert result.adopted == 0 diff --git a/tests/clone/test_pipeline_phase.py b/tests/clone/test_pipeline_phase.py new file mode 100644 index 0000000..e69c3cb --- /dev/null +++ b/tests/clone/test_pipeline_phase.py @@ -0,0 +1,265 @@ +"""Tests for ``PipelinePhase``. + +Coverage: +- happy path: source ETL/TASK pipelines created with workflow FK remapped. +- DEFAULT and APP types are skipped (out of clone scope). +- adopt path on name conflict. +- skipped when workflow remap missing. +- dry-run is a no-op. +- abort raises ``NameConflictError``. +- extra active source API keys produce a warning, never a failure. 
+""" + +from __future__ import annotations + +import logging + +import pytest + +from unstract.clone.context import ( + CloneContext, + CloneOptions, + RemapTable, +) +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.pipeline import PipelinePhase +from unstract.clone.report import CloneReport + +PIPELINE_POST_SCHEMA = frozenset( + { + "pipeline_name", + "workflow", + "pipeline_type", + "cron_string", + "app_id", + "app_icon", + "app_url", + "access_control_bundle_id", + "shared_users", + "shared_to_org", + } +) + + +class FakeClient: + def __init__(self, pipelines: list[dict] | None = None): + self.pipelines: list[dict] = list(pipelines or []) + self.posts: list[dict] = [] + self.keys_by_pipeline: dict[str, list[dict]] = {} + self._next = 1 + + def get_post_schema(self, entity_path: str) -> frozenset[str]: + return PIPELINE_POST_SCHEMA + + def list_pipelines( + self, *, name: str | None = None, pipeline_type: str | None = None + ): + result = self.pipelines + if name is not None: + result = [p for p in result if p["pipeline_name"] == name] + if pipeline_type is not None: + result = [p for p in result if p.get("pipeline_type") == pipeline_type] + return list(result) + + def get_pipeline(self, pipeline_id: str) -> dict: + for p in self.pipelines: + if p["id"] == pipeline_id: + return dict(p) + raise KeyError(pipeline_id) + + def create_pipeline(self, payload: dict) -> dict: + new = dict(payload) + new["id"] = f"tgt-pipeline-{self._next:04d}" + self._next += 1 + self.pipelines.append(new) + self.posts.append(new) + return new + + def list_pipeline_keys(self, pipeline_id: str) -> list[dict]: + return list(self.keys_by_pipeline.get(pipeline_id, [])) + + +def _src_pipeline( + id_: str, + name: str, + workflow_id: str, + *, + pipeline_type: str = "ETL", + cron_string: str | None = None, +) -> dict: + return { + "id": id_, + "pipeline_name": name, + "workflow": workflow_id, + "workflow_id": workflow_id, + "workflow_name": "wf", + 
"pipeline_type": pipeline_type, + "active": True, + "scheduled": cron_string is not None, + "cron_string": cron_string, + "app_id": None, + "app_icon": None, + "app_url": None, + "access_control_bundle_id": None, + "shared_users": [], + "shared_to_org": False, + } + + +def _ctx(source, target, *, remap=None, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=remap or RemapTable(), + ) + + +def test_happy_path_creates_pipeline_with_remapped_workflow(): + src = FakeClient([_src_pipeline("src-pl-1", "Daily Invoices", "wf-src-1")]) + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + result = PipelinePhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert result.failed == 0 + posted = tgt.posts[0] + assert posted["pipeline_name"] == "Daily Invoices" + assert posted["workflow"] == "wf-tgt-1" + assert ctx.remap.resolve("pipeline", "src-pl-1") == posted["id"] + + +def test_create_uses_per_id_get_not_stripped_list_payload(): + # list_pipelines can omit fields the create serializer expects. Phase + # must re-fetch the full record via get_pipeline before POSTing. + full = _src_pipeline("src-pl-1", "Daily Invoices", "wf-src-1") + full["cron_string"] = "0 5 * * *" # only present on detail serializer. 
+ stripped = {k: v for k, v in full.items() if k not in ("cron_string",)} + + class StripListFakeClient(FakeClient): + def list_pipelines(self, *, name=None, pipeline_type=None): + base = ( + [stripped] + if ( + (name is None or stripped["pipeline_name"] == name) + and ( + pipeline_type is None + or stripped["pipeline_type"] == pipeline_type + ) + ) + else [] + ) + return list(base) + + def get_pipeline(self, pipeline_id): + assert pipeline_id == full["id"] + return dict(full) + + src = StripListFakeClient([full]) + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + PipelinePhase(ctx).run(CloneReport()) + + posted = tgt.posts[0] + # cron_string only existed on the detail GET — proves we did NOT + # POST the stripped list-item payload. + assert posted["cron_string"] == "0 5 * * *" + + +def test_default_and_app_pipeline_types_are_skipped(): + src = FakeClient( + [ + _src_pipeline( + "src-1", "default-legacy", "wf-src-1", pipeline_type="DEFAULT" + ), + _src_pipeline("src-2", "streamlit-app", "wf-src-1", pipeline_type="APP"), + _src_pipeline("src-3", "real-etl", "wf-src-1", pipeline_type="ETL"), + ] + ) + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + result = PipelinePhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert len(tgt.posts) == 1 + assert tgt.posts[0]["pipeline_name"] == "real-etl" + + +def test_adopts_existing_pipeline_by_name(): + src = FakeClient([_src_pipeline("src-pl-1", "Daily Invoices", "wf-src-1")]) + tgt = FakeClient([{"id": "tgt-existing", "pipeline_name": "Daily Invoices"}]) + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + result = PipelinePhase(ctx).run(CloneReport()) + + assert result.adopted == 1 + assert result.created == 0 + assert tgt.posts == [] + assert ctx.remap.resolve("pipeline", "src-pl-1") 
== "tgt-existing" + + +def test_skipped_when_workflow_remap_missing(): + src = FakeClient([_src_pipeline("src-pl-1", "Orphan", "wf-src-1")]) + tgt = FakeClient() + ctx = _ctx(src, tgt) # No workflow remap. + + result = PipelinePhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert result.failed == 0 + assert tgt.posts == [] + + +def test_dry_run_makes_no_writes(): + src = FakeClient([_src_pipeline("src-pl-1", "Daily Invoices", "wf-src-1")]) + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap, dry_run=True) + + result = PipelinePhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert tgt.posts == [] + + +def test_abort_on_name_conflict_raises(): + src = FakeClient([_src_pipeline("src-pl-1", "Daily Invoices", "wf-src-1")]) + tgt = FakeClient([{"id": "tgt-existing", "pipeline_name": "Daily Invoices"}]) + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap, on_name_conflict="abort") + + with pytest.raises(NameConflictError): + PipelinePhase(ctx).run(CloneReport()) + + +def test_extra_source_keys_log_warning_not_failure(caplog): + src = FakeClient([_src_pipeline("src-pl-1", "Daily Invoices", "wf-src-1")]) + src.keys_by_pipeline["src-pl-1"] = [ + {"id": "k1", "is_active": True}, + {"id": "k2", "is_active": True}, + {"id": "k3", "is_active": False}, + ] + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + with caplog.at_level(logging.WARNING, logger="unstract.clone.phases.pipeline"): + result = PipelinePhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert result.failed == 0 + assert any("2 active API keys" in r.message for r in caplog.records) diff --git a/tests/clone/test_remap_table.py b/tests/clone/test_remap_table.py new file mode 100644 index 0000000..6a13754 --- /dev/null +++ b/tests/clone/test_remap_table.py @@ 
-0,0 +1,36 @@ +"""Tests for ``RemapTable``.""" + +from unstract.clone.context import RemapTable + + +def test_record_and_resolve_per_entity(): + t = RemapTable() + t.record("adapter", "src-1", "tgt-1") + t.record("adapter", "src-2", "tgt-2") + t.record("connector", "src-1", "tgt-99") + + assert t.resolve("adapter", "src-1") == "tgt-1" + assert t.resolve("adapter", "src-2") == "tgt-2" + assert t.resolve("connector", "src-1") == "tgt-99" + + +def test_resolve_missing_returns_none(): + t = RemapTable() + assert t.resolve("adapter", "nope") is None + assert t.resolve_any("nope") is None + + +def test_resolve_any_searches_across_entities(): + t = RemapTable() + t.record("adapter", "src-a", "tgt-a") + t.record("workflow", "src-w", "tgt-w") + assert t.resolve_any("src-a") == "tgt-a" + assert t.resolve_any("src-w") == "tgt-w" + + +def test_snapshot_is_independent_copy(): + t = RemapTable() + t.record("adapter", "src-1", "tgt-1") + snap = t.snapshot() + t.record("adapter", "src-2", "tgt-2") + assert "src-2" not in snap["adapter"] diff --git a/tests/clone/test_tag_phase.py b/tests/clone/test_tag_phase.py new file mode 100644 index 0000000..f6086a9 --- /dev/null +++ b/tests/clone/test_tag_phase.py @@ -0,0 +1,109 @@ +"""Tests for ``TagPhase``. + +Tag is the simplest entity — no encryption, no list-vs-detail divergence. +Suite covers happy / idempotency / dry-run / abort. 
+""" + +from __future__ import annotations + +import pytest + +from unstract.clone.context import ( + CloneContext, + CloneOptions, + RemapTable, +) +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.tag import TagPhase +from unstract.clone.report import CloneReport + + +class FakeClient: + POST_SCHEMA = frozenset({"name", "description"}) + + def __init__(self, tags: list[dict] | None = None): + self.tags: list[dict] = list(tags or []) + self.posts: list[dict] = [] + self._next_id = 1 + + def get_post_schema(self, entity_path): + return self.POST_SCHEMA + + def list_tags(self, *, name=None): + result = self.tags + if name is not None: + result = [t for t in result if t["name"] == name] + return list(result) + + def create_tag(self, payload): + new = dict(payload) + new["id"] = f"tgt-{self._next_id:08d}-0000-0000-0000-000000000000" + self._next_id += 1 + self.tags.append(new) + self.posts.append(new) + return new + + +def _src(id_, name): + return {"id": id_, "name": name, "description": f"{name} desc"} + + +def _ctx(source, target, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=RemapTable(), + ) + + +def test_happy_path_creates_all_and_records_remap(): + src = FakeClient([_src("src-a", "billing"), _src("src-b", "finance")]) + tgt = FakeClient() + ctx = _ctx(src, tgt) + report = CloneReport() + + result = TagPhase(ctx).run(report) + + assert result.created == 2 + assert result.adopted == 0 + assert len(tgt.posts) == 2 + assert ctx.remap.resolve("tag", "src-a") == tgt.posts[0]["id"] + assert ctx.remap.resolve("tag", "src-b") == tgt.posts[1]["id"] + + +def test_idempotency_zero_creates_on_rerun(): + src = FakeClient([_src("src-a", "billing")]) + tgt = FakeClient([{"id": "preexisting", "name": "billing", "description": "x"}]) + ctx = _ctx(src, tgt, on_name_conflict="adopt") + report = CloneReport() + + result = TagPhase(ctx).run(report) + + assert 
result.created == 0 + assert result.adopted == 1 + assert tgt.posts == [] + assert ctx.remap.resolve("tag", "src-a") == "preexisting" + + +def test_dry_run_makes_no_posts(): + src = FakeClient([_src("src-a", "billing")]) + tgt = FakeClient() + ctx = _ctx(src, tgt, dry_run=True) + report = CloneReport() + + result = TagPhase(ctx).run(report) + + assert result.skipped == 1 + assert result.created == 0 + assert tgt.posts == [] + + +def test_abort_on_name_conflict_raises(): + src = FakeClient([_src("src-a", "billing")]) + tgt = FakeClient([{"id": "preexisting", "name": "billing", "description": "x"}]) + ctx = _ctx(src, tgt, on_name_conflict="abort") + report = CloneReport() + + with pytest.raises(NameConflictError): + TagPhase(ctx).run(report) diff --git a/tests/clone/test_tool_instance_phase.py b/tests/clone/test_tool_instance_phase.py new file mode 100644 index 0000000..180c285 --- /dev/null +++ b/tests/clone/test_tool_instance_phase.py @@ -0,0 +1,248 @@ +"""Tests for ``ToolInstancePhase``. + +ToolInstance is unique among phases: +- The source list of "things to clone" comes from the workflow remap + table, not a top-level entity list. +- Create is a two-step dance (POST bare, PATCH metadata) because the + backend rebuilds metadata from defaults on POST. +""" + +from __future__ import annotations + +from unstract.clone.context import ( + CloneContext, + CloneOptions, + RemapTable, +) +from unstract.clone.phases.tool_instance import ToolInstancePhase +from unstract.clone.report import CloneReport + + +class FakeClient: + def __init__(self) -> None: + # Keyed by workflow_id -> list of tool_instances. 
+ self.instances: dict[str, list[dict]] = {} + self.create_calls: list[dict] = [] + self.patch_calls: list[tuple[str, dict]] = [] + self._next = 1 + + def _mint(self) -> str: + s = f"tgt-ti-{self._next:04d}" + self._next += 1 + return s + + def list_tool_instances(self, *, workflow_id: str | None = None) -> list[dict]: + if workflow_id is None: + return [ti for instances in self.instances.values() for ti in instances] + return list(self.instances.get(workflow_id, [])) + + def create_tool_instance(self, payload: dict) -> dict: + wf = payload["workflow_id"] + new = {**payload, "id": self._mint(), "metadata": {"defaults": True}} + self.instances.setdefault(wf, []).append(new) + self.create_calls.append(new) + return new + + def update_tool_instance_metadata(self, instance_id: str, metadata: dict) -> dict: + self.patch_calls.append((instance_id, metadata)) + for wf_instances in self.instances.values(): + for ti in wf_instances: + if ti["id"] == instance_id: + ti["metadata"] = metadata + return ti + raise KeyError(instance_id) + + +def _ctx(source, target, *, remap=None, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=remap or RemapTable(), + ) + + +def _src_ti(ti_id: str, wf_id: str, tool_id: str, metadata: dict) -> dict: + return { + "id": ti_id, + "workflow": wf_id, + "tool_id": tool_id, + "metadata": metadata, + "step": 1, + } + + +SRC_WF = "10000000-0000-0000-0000-000000000001" +TGT_WF = "20000000-0000-0000-0000-000000000001" +SRC_REG = "30000000-0000-0000-0000-000000000001" +TGT_REG = "40000000-0000-0000-0000-000000000001" + + +def _seed_remap() -> RemapTable: + remap = RemapTable() + remap.record("workflow", SRC_WF, TGT_WF) + remap.record("prompt_studio_registry", SRC_REG, TGT_REG) + return remap + + +def test_happy_path_creates_instance_then_patches_metadata(): + src = FakeClient() + src.instances[SRC_WF] = [ + _src_ti( + "src-ti-1", + SRC_WF, + SRC_REG, + { + "llm": "My OpenAI", + 
"embedding": "MyEmb", + # Identity fields that the backend populated server-side + # at source create time — must NOT cross the org boundary. + "tenant_id": "src-org", + "prompt_registry_id": "src-registry-uuid", + "tool_instance_id": "src-ti-1-pk", + }, + ) + ] + tgt = FakeClient() + ctx = _ctx(src, tgt, remap=_seed_remap()) + + result = ToolInstancePhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert result.failed == 0 + assert len(tgt.create_calls) == 1 + posted = tgt.create_calls[0] + assert posted["workflow_id"] == TGT_WF + assert posted["tool_id"] == TGT_REG + # PATCH carries source settings but stamps identity fields with + # target values — backend PATCH overwrites the whole metadata dict. + assert len(tgt.patch_calls) == 1 + patched_id, patched_metadata = tgt.patch_calls[0] + assert patched_id == posted["id"] + assert patched_metadata == { + "llm": "My OpenAI", + "embedding": "MyEmb", + "prompt_registry_id": TGT_REG, + "tool_instance_id": posted["id"], + } + assert ctx.remap.resolve("tool_instance", "src-ti-1") == posted["id"] + + +def test_skip_when_registry_remap_missing(): + src = FakeClient() + src.instances[SRC_WF] = [_src_ti("src-ti-1", SRC_WF, "unknown-reg", {})] + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", SRC_WF, TGT_WF) + # No prompt_studio_registry remap entry → SDK must skip. 
+ ctx = _ctx(src, tgt, remap=remap) + + result = ToolInstancePhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert result.created == 0 + assert tgt.create_calls == [] + + +def test_adopt_existing_target_instance_and_repatch_metadata(): + src = FakeClient() + src_meta = {"llm": "My OpenAI"} + src.instances[SRC_WF] = [_src_ti("src-ti-1", SRC_WF, SRC_REG, src_meta)] + tgt = FakeClient() + tgt.instances[TGT_WF] = [ + {"id": "tgt-pre-ti", "workflow": TGT_WF, "tool_id": TGT_REG, "metadata": {}} + ] + ctx = _ctx(src, tgt, remap=_seed_remap()) + + result = ToolInstancePhase(ctx).run(CloneReport()) + + assert result.adopted == 1 + assert result.created == 0 + assert tgt.create_calls == [] + # PATCH still fires on adopt and stamps identity fields with target + # values so the runtime can resolve the registry. + assert tgt.patch_calls == [ + ( + "tgt-pre-ti", + { + "llm": "My OpenAI", + "prompt_registry_id": TGT_REG, + "tool_instance_id": "tgt-pre-ti", + }, + ) + ] + assert ctx.remap.resolve("tool_instance", "src-ti-1") == "tgt-pre-ti" + + +def test_no_op_when_no_workflows_in_remap(): + src = FakeClient() + tgt = FakeClient() + ctx = _ctx(src, tgt, remap=RemapTable()) + + result = ToolInstancePhase(ctx).run(CloneReport()) + + assert result.created == 0 + assert result.skipped == 0 + assert tgt.create_calls == [] + + +def test_broken_adapter_refs_bumps_skipped_and_records_error(): + src = FakeClient() + src.instances[SRC_WF] = [ + _src_ti( + "src-ti-1", + SRC_WF, + SRC_REG, + {"llm": "[DELETED ADAPTER] My OpenAI", "embedding": "MyEmb"}, + ) + ] + tgt = FakeClient() + ctx = _ctx(src, tgt, remap=_seed_remap()) + + result = ToolInstancePhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert result.skipped == 1 + assert tgt.patch_calls == [] + assert any("stale adapter refs" in e for e in result.errors) + + +def test_dry_run_does_not_create_or_patch(): + src = FakeClient() + src.instances[SRC_WF] = [_src_ti("src-ti-1", SRC_WF, SRC_REG, {"x": 1})] + 
tgt = FakeClient() + ctx = _ctx(src, tgt, remap=_seed_remap(), dry_run=True) + + result = ToolInstancePhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert tgt.create_calls == [] + assert tgt.patch_calls == [] + + +def test_dry_run_on_adopt_path_does_not_repatch_target(): + # Target already has a tool_instance for the target workflow. On a + # dry-run, we must NOT PATCH its metadata — the adopt branch used to + # fall through to the PATCH call. + src = FakeClient() + src.instances[SRC_WF] = [_src_ti("src-ti-1", SRC_WF, SRC_REG, {"llm": "My OpenAI"})] + tgt = FakeClient() + tgt.instances[TGT_WF] = [ + { + "id": "tgt-pre-ti", + "workflow": TGT_WF, + "tool_id": TGT_REG, + "metadata": {"existing": "untouched"}, + "step": 1, + } + ] + ctx = _ctx(src, tgt, remap=_seed_remap(), dry_run=True) + + result = ToolInstancePhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert result.adopted == 0 + assert tgt.create_calls == [] + assert tgt.patch_calls == [] + # Remap still gets recorded so downstream dry-run output is coherent. 
+ assert ctx.remap.resolve("tool_instance", "src-ti-1") == "tgt-pre-ti" diff --git a/tests/clone/test_walker.py b/tests/clone/test_walker.py new file mode 100644 index 0000000..5ae0301 --- /dev/null +++ b/tests/clone/test_walker.py @@ -0,0 +1,54 @@ +"""Tests for ``remap_uuids``.""" + +from unstract.clone.context import RemapTable +from unstract.clone.walker import remap_uuids + +SRC_A = "11111111-1111-1111-1111-111111111111" +TGT_A = "22222222-2222-2222-2222-222222222222" +SRC_B = "33333333-3333-3333-3333-333333333333" +TGT_B = "44444444-4444-4444-4444-444444444444" +UNRELATED = "55555555-5555-5555-5555-555555555555" + + +def _populated_remap(): + t = RemapTable() + t.record("adapter", SRC_A, TGT_A) + t.record("workflow", SRC_B, TGT_B) + return t + + +def test_remaps_mapped_uuid_string(): + assert remap_uuids(SRC_A, _populated_remap()) == TGT_A + + +def test_leaves_unmapped_uuid_untouched(): + assert remap_uuids(UNRELATED, _populated_remap()) == UNRELATED + + +def test_leaves_non_uuid_string_alone(): + assert remap_uuids("hello-world", _populated_remap()) == "hello-world" + + +def test_remaps_inside_nested_dict_and_list(): + payload = { + "id": SRC_A, + "config": { + "refs": [SRC_B, "not-a-uuid", UNRELATED], + "nested": {"adapter_id": SRC_A}, + }, + "count": 42, + } + result = remap_uuids(payload, _populated_remap()) + assert result == { + "id": TGT_A, + "config": { + "refs": [TGT_B, "not-a-uuid", UNRELATED], + "nested": {"adapter_id": TGT_A}, + }, + "count": 42, + } + + +def test_handles_non_string_scalars(): + payload = {"a": 1, "b": True, "c": None, "d": 3.14} + assert remap_uuids(payload, _populated_remap()) == payload diff --git a/tests/clone/test_workflow_endpoint_phase.py b/tests/clone/test_workflow_endpoint_phase.py new file mode 100644 index 0000000..811488f --- /dev/null +++ b/tests/clone/test_workflow_endpoint_phase.py @@ -0,0 +1,242 @@ +"""Tests for ``WorkflowEndpointPhase``. 
+ +WorkflowEndpoints are PATCH-only — backend auto-creates them on workflow +POST. Tests verify that the SDK: +- pairs source/target endpoints by ``endpoint_type``; +- remaps the embedded ``connector_instance`` UUID; +- walker-rewrites UUIDs nested in ``configuration``; +- skips the PATCH and records an error when a source connector has no remap. +""" + +from __future__ import annotations + +from unstract.clone.context import ( + CloneContext, + CloneOptions, + RemapTable, +) +from unstract.clone.phases.workflow_endpoint import WorkflowEndpointPhase +from unstract.clone.report import CloneReport + + +class FakeClient: + def __init__(self) -> None: + self.endpoints: dict[str, list[dict]] = {} + self.patch_calls: list[tuple[str, dict]] = [] + + def list_workflow_endpoints(self, *, workflow_id: str | None = None) -> list[dict]: + if workflow_id is None: + return [ep for eps in self.endpoints.values() for ep in eps] + return list(self.endpoints.get(workflow_id, [])) + + def update_workflow_endpoint(self, endpoint_id: str, payload: dict) -> dict: + self.patch_calls.append((endpoint_id, payload)) + for eps in self.endpoints.values(): + for ep in eps: + if ep["id"] == endpoint_id: + ep.update(payload) + return ep + raise KeyError(endpoint_id) + + +def _ctx(source, target, *, remap=None, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=remap or RemapTable(), + ) + + +SRC_WF = "10000000-0000-0000-0000-000000000001" +TGT_WF = "20000000-0000-0000-0000-000000000001" +SRC_CONN = "30000000-0000-0000-0000-000000000001" +TGT_CONN = "40000000-0000-0000-0000-000000000001" + + +def _src_endpoint(ep_id, etype, connector_id, configuration): + return { + "id": ep_id, + "workflow": SRC_WF, + "endpoint_type": etype, + "connection_type": "FILESYSTEM", + "configuration": configuration, + "connector_instance": {"id": connector_id, "connector_name": "src-conn"}, + } + + +def _tgt_endpoint(ep_id, etype): + return { + "id": ep_id,
+ "workflow": TGT_WF, + "endpoint_type": etype, + "connection_type": "", + "configuration": {}, + "connector_instance": None, + } + + +def _seed_remap() -> RemapTable: + remap = RemapTable() + remap.record("workflow", SRC_WF, TGT_WF) + remap.record("connector", SRC_CONN, TGT_CONN) + return remap + + +def test_pairs_endpoints_by_type_and_remaps_connector(): + src = FakeClient() + src.endpoints[SRC_WF] = [ + _src_endpoint( + "src-ep-source", + "SOURCE", + SRC_CONN, + {"connector_id": SRC_CONN, "path": "/in"}, + ), + _src_endpoint( + "src-ep-dest", + "DESTINATION", + SRC_CONN, + {"connector_id": SRC_CONN, "path": "/out"}, + ), + ] + tgt = FakeClient() + tgt.endpoints[TGT_WF] = [ + _tgt_endpoint("tgt-ep-source", "SOURCE"), + _tgt_endpoint("tgt-ep-dest", "DESTINATION"), + ] + ctx = _ctx(src, tgt, remap=_seed_remap()) + + result = WorkflowEndpointPhase(ctx).run(CloneReport()) + + assert result.created == 2 + assert result.failed == 0 + assert len(tgt.patch_calls) == 2 + + patches_by_id = dict(tgt.patch_calls) + src_patch = patches_by_id["tgt-ep-source"] + assert src_patch["connection_type"] == "FILESYSTEM" + assert src_patch["connector_instance_id"] == TGT_CONN + assert src_patch["configuration"]["connector_id"] == TGT_CONN + assert src_patch["configuration"]["path"] == "/in" + + dst_patch = patches_by_id["tgt-ep-dest"] + assert dst_patch["configuration"]["path"] == "/out" + assert dst_patch["connector_instance_id"] == TGT_CONN + + assert ctx.remap.resolve("workflow_endpoint", "src-ep-source") == "tgt-ep-source" + assert ctx.remap.resolve("workflow_endpoint", "src-ep-dest") == "tgt-ep-dest" + + +def test_endpoint_with_null_connection_type_omits_key_in_payload(): + # Source had connection_type=None (rare but legal on the model). + # Must NOT coerce to "" — backend treats blank as a validation + # failure on the enum. Omit the key entirely so backend keeps the + # existing target value. 
+ src = FakeClient() + src.endpoints[SRC_WF] = [ + { + "id": "src-ep-source", + "workflow": SRC_WF, + "endpoint_type": "SOURCE", + "connection_type": None, + "configuration": {}, + "connector_instance": None, + } + ] + tgt = FakeClient() + tgt.endpoints[TGT_WF] = [_tgt_endpoint("tgt-ep-source", "SOURCE")] + ctx = _ctx(src, tgt, remap=_seed_remap()) + + result = WorkflowEndpointPhase(ctx).run(CloneReport()) + + assert result.failed == 0 + assert len(tgt.patch_calls) == 1 + _, payload = tgt.patch_calls[0] + assert "connection_type" not in payload + + +def test_endpoint_without_source_connector_patches_with_null(): + src = FakeClient() + src.endpoints[SRC_WF] = [ + { + "id": "src-ep-source", + "endpoint_type": "SOURCE", + "connection_type": "API", + "configuration": {"foo": "bar"}, + "connector_instance": None, + } + ] + tgt = FakeClient() + tgt.endpoints[TGT_WF] = [_tgt_endpoint("tgt-ep-source", "SOURCE")] + ctx = _ctx(src, tgt, remap=_seed_remap()) + + result = WorkflowEndpointPhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert len(tgt.patch_calls) == 1 + _, payload = tgt.patch_calls[0] + assert payload["connector_instance_id"] is None + assert payload["configuration"] == {"foo": "bar"} + + +def test_unknown_connector_uuid_skips_endpoint_and_flags_error(): + """Source had a connector but its remap is missing — patching with + connector=None would silently detach the endpoint on target. Skip + the PATCH and record an operator-visible error entry instead. 
+ """ + src = FakeClient() + src.endpoints[SRC_WF] = [ + _src_endpoint( + "src-ep-source", + "SOURCE", + "99999999-9999-9999-9999-999999999999", # UUID-shaped, absent from the remap + {}, + ) + ] + tgt = FakeClient() + tgt.endpoints[TGT_WF] = [_tgt_endpoint("tgt-ep-source", "SOURCE")] + ctx = _ctx(src, tgt, remap=_seed_remap()) + + result = WorkflowEndpointPhase(ctx).run(CloneReport()) + + assert result.created == 0 + assert result.skipped == 1 + assert tgt.patch_calls == [] + assert any("unmapped connector" in e for e in result.errors) + + +def test_missing_target_endpoint_fails_loudly(): + src = FakeClient() + src.endpoints[SRC_WF] = [_src_endpoint("src-ep-source", "SOURCE", SRC_CONN, {})] + tgt = FakeClient() + tgt.endpoints[TGT_WF] = [] # No endpoints — anomaly. + ctx = _ctx(src, tgt, remap=_seed_remap()) + + result = WorkflowEndpointPhase(ctx).run(CloneReport()) + + assert result.failed == 1 + assert tgt.patch_calls == [] + + +def test_dry_run_makes_no_patches(): + src = FakeClient() + src.endpoints[SRC_WF] = [_src_endpoint("src-ep-source", "SOURCE", SRC_CONN, {})] + tgt = FakeClient() + tgt.endpoints[TGT_WF] = [_tgt_endpoint("tgt-ep-source", "SOURCE")] + ctx = _ctx(src, tgt, remap=_seed_remap(), dry_run=True) + + result = WorkflowEndpointPhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert tgt.patch_calls == [] + + +def test_no_workflows_in_remap_is_noop(): + src = FakeClient() + tgt = FakeClient() + ctx = _ctx(src, tgt, remap=RemapTable()) + + result = WorkflowEndpointPhase(ctx).run(CloneReport()) + + assert result.created == 0 + assert tgt.patch_calls == [] diff --git a/tests/clone/test_workflow_phase.py b/tests/clone/test_workflow_phase.py new file mode 100644 index 0000000..638ed83 --- /dev/null +++ b/tests/clone/test_workflow_phase.py @@ -0,0 +1,155 @@ +"""Tests for ``WorkflowPhase``. + +Coverage: +- happy path: source workflow created on target, connector UUIDs in + ``source_settings`` / ``destination_settings`` rewritten via walker.
+- idempotency: re-run on existing target adopts and doesn't duplicate. +- dry-run: no POST. +- abort on name conflict. +""" + +from __future__ import annotations + +import pytest + +from unstract.clone.context import ( + CloneContext, + CloneOptions, + RemapTable, +) +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.workflow import WorkflowPhase +from unstract.clone.report import CloneReport + + +WORKFLOW_POST_SCHEMA = frozenset( + { + "workflow_name", + "description", + "is_active", + "deployment_type", + "source_settings", + "destination_settings", + "max_file_execution_count", + "shared_users", + "shared_to_org", + } +) + + +class FakeClient: + def __init__(self, workflows: list[dict] | None = None): + self.workflows: list[dict] = list(workflows or []) + self.posts: list[dict] = [] + self._next_id = 1 + + def get_post_schema(self, entity_path: str) -> frozenset[str]: + return WORKFLOW_POST_SCHEMA + + def list_workflows(self, *, name: str | None = None): + result = self.workflows + if name is not None: + result = [w for w in result if w["workflow_name"] == name] + return list(result) + + def create_workflow(self, payload: dict) -> dict: + new = dict(payload) + new["id"] = f"tgt-{self._next_id:08d}-0000-0000-0000-000000000000" + self._next_id += 1 + self.workflows.append(new) + self.posts.append(new) + return new + + +def _src(id_, name, *, source_settings=None, destination_settings=None): + return { + "id": id_, + "workflow_name": name, + "description": f"{name} desc", + "is_active": True, + "deployment_type": "DEFAULT", + "source_settings": source_settings or {}, + "destination_settings": destination_settings or {}, + "max_file_execution_count": None, + "shared_users": [], + "shared_to_org": False, + } + + +def _ctx(source, target, *, remap=None, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=remap or RemapTable(), + ) + + +def 
test_happy_path_creates_workflow_and_remaps_connector_uuids(): + src_conn = "11111111-1111-1111-1111-111111111111" + tgt_conn = "a1111111-1111-1111-1111-111111111111" + src = FakeClient( + [ + _src( + "wf-src-1", + "Invoice ETL", + source_settings={"connector_id": src_conn, "extras": {"a": 1}}, + destination_settings={"connector_id": src_conn}, + ) + ] + ) + tgt = FakeClient() + remap = RemapTable() + remap.record("connector", src_conn, tgt_conn) + ctx = _ctx(src, tgt, remap=remap) + + result = WorkflowPhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert result.failed == 0 + assert len(tgt.posts) == 1 + posted = tgt.posts[0] + # Walker rewrote both occurrences of the source connector UUID. + assert posted["source_settings"]["connector_id"] == tgt_conn + assert posted["destination_settings"]["connector_id"] == tgt_conn + # Unrelated nested data passes through untouched. + assert posted["source_settings"]["extras"] == {"a": 1} + + assert ctx.remap.resolve("workflow", "wf-src-1") == posted["id"] + + +def test_idempotent_rerun_adopts_existing_workflow(): + src = FakeClient([_src("wf-src-1", "Invoice ETL")]) + tgt = FakeClient( + [{"id": "wf-tgt-pre", "workflow_name": "Invoice ETL"}] + ) + ctx = _ctx(src, tgt, on_name_conflict="adopt") + + result = WorkflowPhase(ctx).run(CloneReport()) + + assert result.adopted == 1 + assert result.created == 0 + assert tgt.posts == [] + assert ctx.remap.resolve("workflow", "wf-src-1") == "wf-tgt-pre" + + +def test_dry_run_creates_nothing(): + src = FakeClient([_src("wf-src-1", "Invoice ETL")]) + tgt = FakeClient() + ctx = _ctx(src, tgt, dry_run=True) + + result = WorkflowPhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert tgt.posts == [] + + +def test_abort_on_name_conflict_raises(): + src = FakeClient([_src("wf-src-1", "Invoice ETL")]) + tgt = FakeClient( + [{"id": "wf-tgt-pre", "workflow_name": "Invoice ETL"}] + ) + ctx = _ctx(src, tgt, on_name_conflict="abort") + + with 
pytest.raises(NameConflictError): + WorkflowPhase(ctx).run(CloneReport()) diff --git a/uv.lock b/uv.lock index 25478df..8710285 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.11" [[package]] @@ -102,6 +102,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" }, ] +[[package]] +name = "click" +version = "8.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9b/98/518d8e5081007684232226f475082b30087d0f585e8457db087298259f49/click-8.4.1.tar.gz", hash = "sha256:918b5633eddf6b41c32d4f454bf0de810065c74e3f7dbf8ee5452f8be88d3e96", size = 353007, upload-time = "2026-05-22T04:08:37.769Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/0d/67e5b4109ea4a837e80daa87c2c696711955e40449a97e8926672534def2/click-8.4.1-py3-none-any.whl", hash = "sha256:482be17c6991b8c19c5429a1e995d9b0efdbb63172824c41f99965dc0ade8ec2", size = 116639, upload-time = "2026-05-22T04:08:35.26Z" }, +] + [[package]] name = "colorama" version = "0.4.6" @@ -282,6 +294,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, ] +[[package]] +name = "markdown-it-py" +version = "4.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/06/ff/7841249c247aa650a76b9ee4bbaeae59370dc8bfd2f6c01f3630c35eb134/markdown_it_py-4.2.0.tar.gz", hash = "sha256:04a21681d6fbb623de53f6f364d352309d4094dd4194040a10fd51833e418d49", size = 82454, upload-time = "2026-05-07T12:08:28.36Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/81/4da04ced5a082363ecfa159c010d200ecbd959ae410c10c0264a38cac0f5/markdown_it_py-4.2.0-py3-none-any.whl", hash = "sha256:9f7ebbcd14fe59494226453aed97c1070d83f8d24b6fc3a3bcf9a38092641c4a", size = 91687, upload-time = "2026-05-07T12:08:27.182Z" }, +] + [[package]] name = "mbstrdecoder" version = "1.1.4" @@ -294,6 +318,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/30/ac/5ce64a1d4cce00390beab88622a290420401f1cabf05caf2fc0995157c21/mbstrdecoder-1.1.4-py3-none-any.whl", hash = "sha256:03dae4ec50ec0d2ff4743e63fdbd5e0022815857494d35224b60775d3d934a8c", size = 7933, upload-time = "2025-01-18T10:07:29.562Z" }, ] +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, +] + [[package]] name = "mypy" version = "1.10.1" @@ -593,6 +626,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = 
"2025-08-18T20:46:00.542Z" }, ] +[[package]] +name = "rich" +version = "15.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c0/8f/0722ca900cc807c13a6a0c696dacf35430f72e0ec571c4275d2371fca3e9/rich-15.0.0.tar.gz", hash = "sha256:edd07a4824c6b40189fb7ac9bc4c52536e9780fbbfbddf6f1e2502c31b068c36", size = 230680, upload-time = "2026-04-12T08:24:00.75Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl", hash = "sha256:33bd4ef74232fb73fe9279a257718407f169c09b78a87ad3d296f548e27de0bb", size = 310654, upload-time = "2026-04-12T08:24:02.83Z" }, +] + [[package]] name = "ruff" version = "0.15.1" @@ -757,6 +803,12 @@ dependencies = [ { name = "tenacity" }, ] +[package.optional-dependencies] +clone = [ + { name = "click" }, + { name = "rich" }, +] + [package.dev-dependencies] dev = [ { name = "docutils" }, @@ -789,9 +841,12 @@ test = [ [package.metadata] requires-dist = [ + { name = "click", marker = "extra == 'clone'", specifier = ">=8.1" }, { name = "requests", specifier = ">=2.32.3" }, + { name = "rich", marker = "extra == 'clone'", specifier = ">=13.7" }, { name = "tenacity", specifier = ">=8.2.0" }, ] +provides-extras = ["clone"] [package.metadata.requires-dev] dev = [