From 0cc9f022a79db3133a2a006ee3d8c49dce6a8196 Mon Sep 17 00:00:00 2001 From: Muzi Wong <178915912+MuziWong@users.noreply.github.com> Date: Mon, 29 Jun 2026 16:23:25 +0800 Subject: [PATCH 1/7] feat: add text/image-to-3d-scene pipeline --- embodichain/gen_sim/prompt2scene/.gitignore | 7 + embodichain/gen_sim/prompt2scene/__init__.py | 15 + .../prompt2scene/agent_tools/__init__.py | 1 + .../agent_tools/clients/__init__.py | 31 + .../prompt2scene/agent_tools/clients/base.py | 131 ++++ .../agent_tools/clients/common.py | 139 ++++ .../agent_tools/clients/config.py | 50 ++ .../geometry_generation_client/__init__.py | 49 ++ .../geometry_generation_client/client.py | 213 ++++++ .../geometry_generation_client/parser.py | 255 +++++++ .../geometry_generation_client/schemas.py | 134 ++++ .../image_generation_client/__init__.py | 39 ++ .../clients/image_generation_client/client.py | 117 ++++ .../clients/image_generation_client/parser.py | 65 ++ .../image_generation_client/schemas.py | 72 ++ .../image_segmentation_client/__init__.py | 61 ++ .../image_segmentation_client/client.py | 132 ++++ .../image_segmentation_client/parser.py | 218 ++++++ .../image_segmentation_client/schemas.py | 103 +++ .../image_segmentation_client/utils.py | 322 +++++++++ .../blender_rendering_manager/__init__.py | 31 + .../blender_rendering_manager/manager.py | 175 +++++ .../blender_rendering_manager/schemas.py | 39 ++ .../geometry_generation_manager/__init__.py | 45 ++ .../geometry_generation_manager/manager.py | 209 ++++++ .../geometry_generation_manager/schemas.py | 105 +++ .../managers/geometry_manager/__init__.py | 69 ++ .../managers/geometry_manager/manager.py | 584 ++++++++++++++++ .../geometry_manager/scene_geometry.py | 567 ++++++++++++++++ .../managers/geometry_manager/schemas.py | 201 ++++++ .../image_generation_manager/__init__.py | 35 + .../image_generation_manager/manager.py | 76 +++ .../image_generation_manager/schemas.py | 43 ++ .../managers/image_scene_manager/__init__.py | 29 + .../managers/image_scene_manager/alignment.py | 537 +++++++++++++++ .../managers/image_scene_manager/manifests.py | 212 ++++++ .../managers/image_scene_manager/prompts.py | 106 +++ .../managers/image_scene_manager/schemas.py | 71 ++ .../image_segmentation_manager/__init__.py | 33 + .../image_segmentation_manager/manager.py | 90 +++ .../image_segmentation_manager/schemas.py | 48 ++ .../managers/matplotlib_manager/__init__.py | 43 ++ .../managers/matplotlib_manager/manager.py | 401 +++++++++++ .../managers/matplotlib_manager/schemas.py | 101 +++ .../managers/metric_scale_manager/__init__.py | 37 + .../managers/metric_scale_manager/manager.py | 431 ++++++++++++ .../managers/metric_scale_manager/schemas.py | 73 ++ .../managers/optimization_manager/__init__.py | 37 + .../managers/optimization_manager/manager.py | 633 +++++++++++++++++ .../managers/simready_manager/__init__.py | 35 + .../managers/simready_manager/manager.py | 396 +++++++++++ .../managers/simready_manager/schemas.py | 58 ++ .../managers/simulation_manager/__init__.py | 31 + .../managers/simulation_manager/manager.py | 124 ++++ .../managers/simulation_manager/schemas.py | 42 ++ .../table_clutter_fit_manager/__init__.py | 23 + .../table_clutter_fit_manager/manager.py | 298 ++++++++ .../managers/text_layout_manager/__init__.py | 33 + .../managers/text_layout_manager/layout.py | 383 +++++++++++ .../text_layout_manager/optimization.py | 404 +++++++++++ .../managers/text_layout_manager/settle.py | 429 ++++++++++++ .../agent_tools/servers/__init__.py | 16 + .../agent_tools/tools/__init__.py | 19 + .../agent_tools/tools/gym_export.py | 319 +++++++++ .../tools/image_scene_asset_generation.py | 636 ++++++++++++++++++ .../agent_tools/tools/table_fit_scene.py | 105 +++ .../tools/text_asset_generation.py | 294 ++++++++ .../agent_tools/tools/text_clutter_layout.py | 62 ++ .../tools/text_scene_metric_scale.py | 161 +++++ .../gen_sim/prompt2scene/cli/__init__.py | 19 + embodichain/gen_sim/prompt2scene/cli/start.py | 90 +++ .../prompt2scene/configs/client_config.json | 21 + .../prompt2scene/configs/llm_config.json | 11 + .../gen_sim/prompt2scene/llms/__init__.py | 31 + .../gen_sim/prompt2scene/llms/config.py | 49 ++ .../prompt2scene/llms/openai_compatible.py | 115 ++++ .../gen_sim/prompt2scene/pipeline/__init__.py | 25 + .../gen_sim/prompt2scene/pipeline/runner.py | 239 +++++++ .../gen_sim/prompt2scene/prompts/__init__.py | 48 ++ .../gen_sim/prompt2scene/prompts/base.py | 79 +++ .../prompt2scene/prompts/data/__init__.py | 21 + .../prompts/data/image_relations.yaml | 238 +++++++ .../prompts/data/scene_intake.yaml | 468 +++++++++++++ .../prompts/data/text_relations.yaml | 110 +++ .../prompts/data/unified_scene_gen.yaml | 225 +++++++ .../gen_sim/prompt2scene/utils/__init__.py | 39 ++ embodichain/gen_sim/prompt2scene/utils/io.py | 66 ++ embodichain/gen_sim/prompt2scene/utils/log.py | 62 ++ .../prompt2scene/workflows/__init__.py | 41 ++ .../prompt2scene/workflows/artifact_writer.py | 271 ++++++++ .../prompt2scene/workflows/attempt_state.py | 30 + .../workflows/image_relations/__init__.py | 24 + .../workflows/image_relations/graph.py | 189 ++++++ .../workflows/image_relations/nodes.py | 511 ++++++++++++++ .../workflows/image_relations/prompts.py | 113 ++++ .../workflows/image_relations/schema.py | 250 +++++++ .../workflows/image_relations/state.py | 42 ++ .../workflows/image_relations/utils.py | 435 ++++++++++++ .../prompt2scene/workflows/llm_output.py | 285 ++++++++ .../gen_sim/prompt2scene/workflows/request.py | 110 +++ .../workflows/scene_intake/__init__.py | 24 + .../workflows/scene_intake/graph.py | 142 ++++ .../workflows/scene_intake/nodes.py | 211 ++++++ .../workflows/scene_intake/prompts.py | 197 ++++++ .../workflows/scene_intake/schema.py | 244 +++++++ .../workflows/scene_intake/state.py | 37 + .../workflows/scene_intake/utils.py | 229 +++++++ .../gen_sim/prompt2scene/workflows/spatial.py | 309 +++++++++ .../prompt2scene/workflows/stage_errors.py | 40 ++ .../workflows/text_relations/__init__.py | 24 + .../workflows/text_relations/graph.py | 124 ++++ .../workflows/text_relations/nodes.py | 144 ++++ .../workflows/text_relations/prompts.py | 55 ++ .../workflows/text_relations/schema.py | 164 +++++ .../workflows/text_relations/state.py | 42 ++ .../workflows/text_relations/utils.py | 191 ++++++ .../workflows/unified_scene/__init__.py | 19 + .../workflows/unified_scene/graph.py | 97 +++ .../workflows/unified_scene/nodes.py | 57 ++ .../workflows/unified_scene/schema.py | 157 +++++ .../workflows/unified_scene/state.py | 45 ++ .../workflows/unified_scene/utils.py | 332 +++++++++ .../workflows/unified_scene_gen/__init__.py | 27 + .../workflows/unified_scene_gen/graph.py | 106 +++ .../workflows/unified_scene_gen/nodes.py | 392 +++++++++++ .../workflows/unified_scene_gen/paths.py | 102 +++ .../workflows/unified_scene_gen/prompts.py | 141 ++++ .../unified_scene_gen/scene_update.py | 76 +++ .../workflows/unified_scene_gen/schema.py | 71 ++ .../workflows/unified_scene_gen/state.py | 40 ++ 130 files changed, 19179 insertions(+) create mode 100644 embodichain/gen_sim/prompt2scene/.gitignore create mode 100644 embodichain/gen_sim/prompt2scene/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/base.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/common.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/config.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/client.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/parser.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/client.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/parser.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/client.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/parser.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/utils.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_generation_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_generation_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_generation_manager/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/scene_geometry.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_generation_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_generation_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_generation_manager/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/alignment.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/manifests.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/prompts.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_segmentation_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_segmentation_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_segmentation_manager/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/simulation_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/simulation_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/simulation_manager/schemas.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/layout.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/optimization.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/settle.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/servers/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/tools/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/tools/table_fit_scene.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/tools/text_clutter_layout.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/tools/text_scene_metric_scale.py create mode 100644 embodichain/gen_sim/prompt2scene/cli/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/cli/start.py create mode 100644 embodichain/gen_sim/prompt2scene/configs/client_config.json create mode 100644 embodichain/gen_sim/prompt2scene/configs/llm_config.json create mode 100644 embodichain/gen_sim/prompt2scene/llms/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/llms/config.py create mode 100644 embodichain/gen_sim/prompt2scene/llms/openai_compatible.py create mode 100644 embodichain/gen_sim/prompt2scene/pipeline/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/pipeline/runner.py create mode 100644 embodichain/gen_sim/prompt2scene/prompts/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/prompts/base.py create mode 100644 embodichain/gen_sim/prompt2scene/prompts/data/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/prompts/data/image_relations.yaml create mode 100644 embodichain/gen_sim/prompt2scene/prompts/data/scene_intake.yaml create mode 100644 embodichain/gen_sim/prompt2scene/prompts/data/text_relations.yaml create mode 100644 embodichain/gen_sim/prompt2scene/prompts/data/unified_scene_gen.yaml create mode 100644 embodichain/gen_sim/prompt2scene/utils/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/utils/io.py create mode 100644 embodichain/gen_sim/prompt2scene/utils/log.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/artifact_writer.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/attempt_state.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/image_relations/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/image_relations/graph.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/image_relations/nodes.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/image_relations/prompts.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/image_relations/schema.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/image_relations/state.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/image_relations/utils.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/llm_output.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/request.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/scene_intake/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/scene_intake/graph.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/scene_intake/nodes.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/scene_intake/prompts.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/scene_intake/schema.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/scene_intake/state.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/scene_intake/utils.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/spatial.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/stage_errors.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/text_relations/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/text_relations/graph.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/text_relations/nodes.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/text_relations/prompts.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/text_relations/schema.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/text_relations/state.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/text_relations/utils.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene/graph.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene/nodes.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene/schema.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene/state.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene/utils.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/graph.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/nodes.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/paths.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/prompts.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/scene_update.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/schema.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/state.py diff --git a/embodichain/gen_sim/prompt2scene/.gitignore b/embodichain/gen_sim/prompt2scene/.gitignore new file mode 100644 index 000000000..75f4908e8 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/.gitignore @@ -0,0 +1,7 @@ +cli/preview* +cli/export* +agent_tools/servers/geometry_generation_server/* + +# Python cache +__pycache__/ +*.py[cod] diff --git a/embodichain/gen_sim/prompt2scene/__init__.py b/embodichain/gen_sim/prompt2scene/__init__.py new file mode 100644 index 000000000..01ece10d4 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/__init__.py @@ -0,0 +1,15 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- \ No newline at end of file diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/__init__.py new file mode 100644 index 000000000..a4b11ff06 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/__init__.py @@ -0,0 +1 @@ +"""Internal client + External server for agent tool calling.""" diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/__init__.py new file mode 100644 index 000000000..3afc32bd0 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/__init__.py @@ -0,0 +1,31 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.base import BaseHttpClient +from embodichain.gen_sim.prompt2scene.agent_tools.clients.common import ClientError +from embodichain.gen_sim.prompt2scene.agent_tools.clients.config import ( + DEFAULT_CLIENT_CONFIG_PATH, + load_client_config, +) + +__all__ = [ + "BaseHttpClient", + "ClientError", + "DEFAULT_CLIENT_CONFIG_PATH", + "load_client_config", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/base.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/base.py new file mode 100644 index 000000000..8981602f6 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/base.py @@ -0,0 +1,131 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import time +from pathlib import Path +from typing import Callable + +import requests + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.common import ( + ClientError, + build_client_error, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.config import ( + load_client_config, +) +from embodichain.gen_sim.prompt2scene.utils.log import ( + log_api_request_start, + log_info, + log_warning, +) + +__all__ = ["BaseHttpClient"] + + +class BaseHttpClient: + """Shared HTTP client behavior for agent-tool service clients.""" + + def __init__( + self, + *, + config_key: str, + server_name: str, + base_url: str | None = None, + timeout_s: int | None = None, + config_path: Path | None = None, + session: requests.Session | None = None, + trust_env: bool = True, + ) -> None: + """Initialize common service client fields from config.""" + self.config = load_client_config(config_key, config_path) + self.server_name = server_name + self.base_url = (base_url or str(self.config["base_url"])).rstrip("/") + self.timeout_s = int(timeout_s or self.config.get("timeout_s", 120)) + self.health_path = str(self.config.get("health_path", "/health")) + self.session = session or requests.Session() + self.session.trust_env = trust_env + log_info(f"{self.server_name} client initialized for {self.base_url}") + + def health_check(self) -> bool: + """Check whether the configured service is healthy.""" + try: + response = self.session.get( + f"{self.base_url}{self.health_path}", + timeout=5, + ) + response.raise_for_status() + return True + except Exception as exc: + log_warning(f"{self.server_name} health check failed: {exc}") + return False + + def post_with_retries( + self, + request_fn: Callable[[], requests.Response], + *, + max_retries: int, + error_cls: type[ClientError] = ClientError, + request_label: str | None = None, + ) -> requests.Response | ClientError: + """Run a POST request function with retry and HTTP error handling.""" + for attempt in range(max_retries): + try: + if request_label is not None: + log_api_request_start( + step=self.server_name, + request=request_label, + attempt=attempt + 1, + ) + response = request_fn() + response.raise_for_status() + return response + + except requests.exceptions.ConnectionError as exc: + if attempt < max_retries - 1: + log_warning( + f"{self.server_name} connection failed; retrying " + f"({attempt + 1}/{max_retries})." + ) + time.sleep(min(2**attempt, 60)) + continue + raise ConnectionError( + f"Failed to connect to {self.server_name} at {self.base_url}" + ) from exc + + except requests.exceptions.HTTPError as exc: + response = exc.response + if response is None: + raise RuntimeError(f"{self.server_name} HTTP request failed.") from exc + if response.status_code >= 500 and attempt < max_retries - 1: + log_warning( + f"{self.server_name} server error; retrying " + f"({attempt + 1}/{max_retries})." + ) + time.sleep(min(2**attempt, 60)) + continue + return build_client_error( + response, + server_name=self.server_name, + error_cls=error_cls, + ) + + except requests.exceptions.Timeout as exc: + raise TimeoutError(f"{self.server_name} request timed out.") from exc + + raise RuntimeError(f"{self.server_name} request failed unexpectedly.") diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/common.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/common.py new file mode 100644 index 000000000..f1c7dc690 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/common.py @@ -0,0 +1,139 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +import requests + +__all__ = [ + "ClientError", + "build_client_error", + "first_string", + "format_http_error", + "parse_error_response", + "parse_json_object_response", + "validate_required_strings", + "validate_png_response", +] + + +@dataclass(frozen=True) +class ClientError: + """Common HTTP client error response.""" + + error_message: str + status_code: int | None = None + content_type: str | None = None + headers: dict[str, str] = field(default_factory=dict) + raw_response: dict[str, Any] | None = None + + +def validate_png_response( + response: requests.Response, + png_bytes: bytes, +) -> None: + content_type = response.headers.get("Content-Type", "") + if "image/png" not in content_type.lower(): + raise RuntimeError( + "Image generation server returned non-PNG content: " + f"{content_type or 'unknown'}" + ) + if not png_bytes.startswith(b"\x89PNG\r\n\x1a\n"): + raise RuntimeError("Image generation server returned invalid PNG bytes.") + + +def validate_required_strings(fields: dict[str, object]) -> None: + """Validate required client request string fields.""" + for field_name, value in fields.items(): + if not str(value).strip(): + raise ValueError(f"{field_name} must be non-empty.") + + +def format_http_error(response: requests.Response, *, server_name: str) -> str: + """Format an HTTP error response from an agent-tool server.""" + try: + response_data = response.json() + except ValueError: + return f"{server_name} HTTP error: {response.status_code}" + + error_message = first_string( + response_data, + "error", + "error_message", + "message", + "detail", + ) + if error_message: + return f"{server_name} error: {error_message}" + return f"{server_name} HTTP error: {response.status_code}" + + +def parse_error_response(response: requests.Response) -> dict[str, Any] | None: + """Parse an error response body as a JSON object if possible.""" + try: + response_data = response.json() + except ValueError: + return None + return response_data if isinstance(response_data, dict) else None + + +def build_client_error( + response: requests.Response, + *, + server_name: str, + error_cls: type[ClientError] = ClientError, +) -> ClientError: + """Build a common client error dataclass from an HTTP response.""" + return error_cls( + error_message=format_http_error( + response, + server_name=server_name, + ), + status_code=response.status_code, + content_type=response.headers.get("Content-Type"), + headers=dict(response.headers), + raw_response=parse_error_response(response), + ) + + +def parse_json_object_response( + response: requests.Response, + *, + server_name: str, +) -> dict[str, Any]: + """Parse an HTTP response body as a JSON object.""" + try: + response_data = response.json() + except ValueError as exc: + raise RuntimeError( + f"{server_name} returned invalid JSON content: " + f"{response.headers.get('Content-Type') or 'unknown'}" + ) from exc + if not isinstance(response_data, dict): + raise RuntimeError(f"{server_name} response must be a JSON object.") + return response_data + + +def first_string(data: dict[str, Any], *keys: str) -> str | None: + """Return the first string value for the given keys.""" + for key in keys: + value = data.get(key) + if isinstance(value, str): + return value + return None diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/config.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/config.py new file mode 100644 index 000000000..5592806a8 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/config.py @@ -0,0 +1,50 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +__all__ = ["DEFAULT_CLIENT_CONFIG_PATH", "load_client_config"] + +DEFAULT_CLIENT_CONFIG_PATH = ( + Path(__file__).resolve().parents[2] / "configs" / "client_config.json" +) + + +def load_client_config( + config_key: str, + config_path: Path | None = None, +) -> dict[str, Any]: + """Load one agent-tool client config section.""" + resolved_config_path = (config_path or DEFAULT_CLIENT_CONFIG_PATH).resolve() + if not resolved_config_path.is_file(): + raise FileNotFoundError(f"Client config not found: {resolved_config_path}") + + with resolved_config_path.open("r", encoding="utf-8") as f: + raw_config = json.load(f) + + config = raw_config.get(config_key) + if not isinstance(config, dict): + raise ValueError( + f"Client config section {config_key!r} not found in " + f"{resolved_config_path}" + ) + if not config.get("base_url"): + raise ValueError(f"Client config section {config_key!r} requires base_url.") + return config diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/__init__.py new file mode 100644 index 000000000..3fa63f3b1 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/__init__.py @@ -0,0 +1,49 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.geometry_generation_client.client import ( + GeometryGenerationClient, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.geometry_generation_client.schemas import ( + GeometryGenerationError, + GeometryGenerationResult, + GeometryGenerationServerRequest, + GeometryGenerationServerResponse, + MultiObjectGenerationError, + MultiObjectGenerationObject, + MultiObjectGenerationResult, + MultiObjectGenerationServerRequest, + MultiObjectGenerationServerResponse, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.config import ( + DEFAULT_CLIENT_CONFIG_PATH, +) + +__all__ = [ + "DEFAULT_CLIENT_CONFIG_PATH", + "GeometryGenerationClient", + "GeometryGenerationError", + "GeometryGenerationResult", + "GeometryGenerationServerRequest", + "GeometryGenerationServerResponse", + "MultiObjectGenerationError", + "MultiObjectGenerationObject", + "MultiObjectGenerationResult", + "MultiObjectGenerationServerRequest", + "MultiObjectGenerationServerResponse", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/client.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/client.py new file mode 100644 index 000000000..0615c6d27 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/client.py @@ -0,0 +1,213 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Client for the SAM3D geometry generation server.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import requests + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.base import BaseHttpClient +from embodichain.gen_sim.prompt2scene.agent_tools.clients.common import ( + validate_required_strings, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.config import ( + DEFAULT_CLIENT_CONFIG_PATH, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.geometry_generation_client.parser import ( + parse_geometry_generation_response, + parse_multi_object_generation_response, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.geometry_generation_client.schemas import ( + GeometryGenerationError, + GeometryGenerationResult, + GeometryGenerationServerRequest, + GeometryGenerationServerResponse, + MultiObjectGenerationError, + MultiObjectGenerationObject, + MultiObjectGenerationServerRequest, + MultiObjectGenerationServerResponse, +) + +__all__ = [ + "DEFAULT_CLIENT_CONFIG_PATH", + "GeometryGenerationClient", + "GeometryGenerationError", + "GeometryGenerationResult", + "GeometryGenerationServerRequest", + "GeometryGenerationServerResponse", + "MultiObjectGenerationError", + "MultiObjectGenerationObject", + "MultiObjectGenerationServerRequest", + "MultiObjectGenerationServerResponse", +] + + +class GeometryGenerationClient(BaseHttpClient): + """Client for making single-object SAM3D geometry generation requests.""" + + def __init__( + self, + *, + base_url: str | None = None, + timeout_s: int | None = None, + config_path: Path | None = None, + config_key: str = "sam3d_generation", + session: requests.Session | None = None, + ) -> None: + """Initialize the geometry generation client.""" + super().__init__( + config_key=config_key, + server_name="Geometry generation server", + base_url=base_url, + timeout_s=timeout_s, + config_path=config_path, + session=session, + trust_env=False, + ) + self.generate_single_object_path = str( + self.config.get("generate_single_object_path", "/generate_single_object") + ) + self.generate_multiple_objects_path = str( + self.config.get( + "generate_multiple_objects_path", "/generate_multiple_objects" + ) + ) + + def generate( + self, + request: GeometryGenerationServerRequest, + *, + max_retries: int = 3, + ) -> GeometryGenerationServerResponse | GeometryGenerationError: + """Generate one GLB mesh from an object image and save it locally.""" + _validate_request(request) + url = f"{self.base_url}{self.generate_single_object_path}" + response = self.post_with_retries( + lambda: _post_geometry_generation_request(self, url, request), + max_retries=max_retries, + error_cls=GeometryGenerationError, + request_label="geometry_generation", + ) + if isinstance(response, GeometryGenerationError): + return response + return parse_geometry_generation_response(response, request) + + def generate_multiple_objects( + self, + request: MultiObjectGenerationServerRequest, + *, + output_dir: Path | None = None, + max_retries: int = 3, + ) -> MultiObjectGenerationServerResponse | MultiObjectGenerationError: + """Generate multiple GLB meshes from one image and multiple masks.""" + _validate_multi_object_request(request) + url = f"{self.base_url}{self.generate_multiple_objects_path}" + response = self.post_with_retries( + lambda: _post_multi_object_generation_request(self, url, request), + max_retries=max_retries, + error_cls=MultiObjectGenerationError, + request_label="multi_object_geometry_generation", + ) + if isinstance(response, MultiObjectGenerationError): + return response + return parse_multi_object_generation_response( + response, + self.base_url, + output_dir=output_dir, + session=self.session, + ) + + +def _validate_request(request: GeometryGenerationServerRequest) -> None: + validate_required_strings( + { + "Geometry generation image_path": request.image_path, + "Geometry generation output_path": request.output_path, + } + ) + image_path = Path(request.image_path).expanduser() + if not image_path.is_file(): + raise FileNotFoundError(f"Geometry generation input not found: {image_path}") + if not str(request.output_path).lower().endswith(".glb"): + raise ValueError("Geometry generation output_path must be a GLB file path.") + + +def _post_geometry_generation_request( + client: GeometryGenerationClient, + url: str, + request: GeometryGenerationServerRequest, +) -> requests.Response: + with _open_image_file(request.image_path) as image_file: + return client.session.post( + url, + data=request.to_form_data(), + files={ + "image": ( + Path(request.image_path).name, + image_file, + ) + }, + timeout=(10, client.timeout_s), + ) + + +def _open_image_file(image_path: str | Path) -> Any: + return Path(image_path).expanduser().resolve().open("rb") + + +def _validate_multi_object_request( + request: MultiObjectGenerationServerRequest, +) -> None: + validate_required_strings( + {"Multi-object geometry generation image_path": request.image_path} + ) + image_path = Path(request.image_path).expanduser() + if not image_path.is_file(): + raise FileNotFoundError( + f"Multi-object geometry generation input not found: {image_path}" + ) + if not request.mask_paths: + raise ValueError("mask_paths must be non-empty.") + for mask_path in request.mask_paths: + if not Path(mask_path).expanduser().is_file(): + raise FileNotFoundError( + f"Multi-object geometry mask not found: {mask_path}" + ) + + +def _post_multi_object_generation_request( + client: GeometryGenerationClient, + url: str, + request: MultiObjectGenerationServerRequest, +) -> requests.Response: + mask_files = [ + ("masks", (Path(p).name, Path(p).expanduser().resolve().open("rb"))) + for p in request.mask_paths + ] + try: + return client.session.post( + url, + data=request.to_form_data(), + files=[("image", (Path(request.image_path).name, _open_image_file(request.image_path)))] + mask_files, + timeout=(10, client.timeout_s), + ) + finally: + for _, (_, f) in mask_files: + f.close() diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/parser.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/parser.py new file mode 100644 index 000000000..4d3c09671 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/parser.py @@ -0,0 +1,255 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path + +import requests + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.geometry_generation_client.schemas import ( + GeometryGenerationResult, + GeometryGenerationServerRequest, + GeometryGenerationServerResponse, + MultiObjectGenerationObject, + MultiObjectGenerationResult, + MultiObjectGenerationServerResponse, +) +from embodichain.gen_sim.prompt2scene.utils.log import log_info + +__all__ = ["parse_geometry_generation_response", "parse_multi_object_generation_response"] + + +def parse_geometry_generation_response( + response: requests.Response, + request: GeometryGenerationServerRequest, +) -> GeometryGenerationServerResponse: + """Parse a geometry GLB response and save it to the request output path.""" + glb_bytes = response.content + _validate_glb_response(response, glb_bytes) + output_path = _write_glb_output(request, glb_bytes) + result = GeometryGenerationResult(geometry_path=str(output_path)) + return GeometryGenerationServerResponse( + ok=True, + status="ok", + result=result, + status_code=response.status_code, + content_type=response.headers.get("Content-Type"), + headers=dict(response.headers), + ) + + +def _validate_glb_response( + response: requests.Response, + glb_bytes: bytes, +) -> None: + if not glb_bytes.startswith(b"glTF"): + content_type = response.headers.get("Content-Type", "") + raise RuntimeError( + "Geometry generation server returned invalid GLB content: " + f"{content_type or 'unknown'}" + ) + + +def _write_glb_output( + request: GeometryGenerationServerRequest, + glb_bytes: bytes, +) -> Path: + output_path = Path(request.output_path).expanduser().resolve() + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_bytes(glb_bytes) + if not output_path.is_file(): + raise FileNotFoundError(f"Generated geometry was not written: {output_path}") + log_info(f"Generated geometry written: {output_path}") + return output_path + + +def parse_multi_object_generation_response( + response: requests.Response, + base_url: str, + *, + output_dir: Path | None = None, + session: requests.Session | None = None, +) -> MultiObjectGenerationServerResponse: + """Parse a multi-object geometry response, download GLBs if output_dir given.""" + body = _parse_json_body(response) + ok = body.get("ok", False) + if not isinstance(ok, bool) or not ok: + error_msg = body.get("error", "ok is not true") + raise RuntimeError( + f"Multi-object geometry generation failed: {error_msg}" + ) + + result_data = body.get("result") + if not isinstance(result_data, dict): + raise RuntimeError( + "Multi-object geometry generation response missing 'result' object" + ) + base = base_url.rstrip("/") + objects = _parse_multi_object_items( + result_data, + base, + output_dir=output_dir, + session=session, + ) + + return MultiObjectGenerationServerResponse( + ok=True, + status=str(body.get("status") or "ok"), + result=MultiObjectGenerationResult(objects=objects), + status_code=response.status_code, + content_type=response.headers.get("Content-Type"), + headers=dict(response.headers), + ) + + +def _parse_multi_object_items( + body: dict[str, object], + base_url: str, + *, + output_dir: Path | None, + session: requests.Session | None, +) -> list[MultiObjectGenerationObject]: + response_objects = body.get("objects") + if not isinstance(response_objects, list) or not response_objects: + raise RuntimeError( + "Multi-object geometry generation response missing 'result.objects' list" + ) + return [ + _parse_multi_object_item( + item, + index=i, + base_url=base_url, + output_dir=output_dir, + session=session, + ) + for i, item in enumerate(response_objects) + ] + + +def _parse_multi_object_item( + item: object, + *, + index: int, + base_url: str, + output_dir: Path | None, + session: requests.Session | None, +) -> MultiObjectGenerationObject: + if not isinstance(item, dict): + raise RuntimeError(f"Multi-object item {index} must be a JSON object") + + mesh_rel_path = item.get("mesh") + if not isinstance(mesh_rel_path, str) or not mesh_rel_path: + raise RuntimeError(f"Multi-object item {index} missing 'mesh'") + + name = str(item.get("name") or Path(mesh_rel_path).stem or index) + geometry_path = _resolve_or_download_glb( + base_url, + mesh_rel_path, + name=name, + index=index, + output_dir=output_dir, + session=session, + ) + + return MultiObjectGenerationObject( + name=name, + geometry_path=geometry_path, + rotation_quaternion_wxyz=_float_list( + item.get("rotation_quaternion_wxyz"), + expected_len=4, + field_name=f"objects[{index}].rotation_quaternion_wxyz", + ), + translation=_float_list( + item.get("translation"), + expected_len=3, + field_name=f"objects[{index}].translation", + ), + scale=_float_list( + item.get("scale"), + expected_len=3, + field_name=f"objects[{index}].scale", + ), + ) + + +def _resolve_or_download_glb( + base_url: str, + mesh_rel_path: str, + *, + name: str, + index: int, + output_dir: Path | None, + session: requests.Session | None, +) -> str: + url = _join_url(base_url, mesh_rel_path) + if output_dir is None: + return url + + output_dir = output_dir.expanduser().resolve() + output_dir.mkdir(parents=True, exist_ok=True) + filename = f"{name}.glb" if name else f"{index}.glb" + dest = output_dir / filename + _download_glb(url, dest, session=session) + return str(dest) + + +def _join_url(base_url: str, path_or_url: str) -> str: + if path_or_url.startswith(("http://", "https://")): + return path_or_url + if path_or_url.startswith("/"): + return f"{base_url}{path_or_url}" + return f"{base_url}/{path_or_url}" + + +def _float_list(value: object, *, expected_len: int, field_name: str) -> list[float]: + if not isinstance(value, list) or len(value) != expected_len: + raise RuntimeError(f"Multi-object geometry response missing '{field_name}'") + try: + return [float(v) for v in value] + except (TypeError, ValueError) as exc: + raise RuntimeError( + f"Multi-object geometry response field '{field_name}' must be numeric" + ) from exc + + +def _parse_json_body(response: requests.Response) -> dict[str, object]: + try: + body = response.json() + except ValueError as exc: + raise RuntimeError( + "Multi-object geometry generation server returned invalid JSON" + ) from exc + if not isinstance(body, dict): + raise RuntimeError( + "Multi-object geometry generation response must be a JSON object" + ) + return body + + +def _download_glb( + url: str, + dest: Path, + *, + session: requests.Session | None, +) -> None: + """Download a GLB from the geometry server.""" + http = session or requests.Session() + r = http.get(url, timeout=30) + r.raise_for_status() + _validate_glb_response(r, r.content) + dest.write_bytes(r.content) + log_info(f"Generated geometry written: {dest}") diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/schemas.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/schemas.py new file mode 100644 index 000000000..d8ede9eea --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/geometry_generation_client/schemas.py @@ -0,0 +1,134 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.common import ClientError + +__all__ = [ + "GeometryGenerationError", + "GeometryGenerationResult", + "GeometryGenerationServerRequest", + "GeometryGenerationServerResponse", + "MultiObjectGenerationError", + "MultiObjectGenerationObject", + "MultiObjectGenerationResult", + "MultiObjectGenerationServerRequest", + "MultiObjectGenerationServerResponse", +] + + +@dataclass(frozen=True) +class GeometryGenerationServerRequest: + """Request sent to the Geometry Generation server. + + Args: + image_path: Local object image path. + output_path: Local output GLB path where the client saves the generated geometry. + """ + + image_path: str | Path + output_path: str | Path + + def to_form_data(self) -> dict[str, str]: + """Convert the request to the geometry server multipart form fields.""" + return {} + + +@dataclass(frozen=True) +class GeometryGenerationResult: + """Successful Geometry Generation result.""" + + geometry_path: str + + +@dataclass(frozen=True) +class GeometryGenerationServerResponse: + """Parsed successful response from the Geometry Generation server.""" + + ok: bool + result: GeometryGenerationResult + status: str | None = None + error: str | None = None + status_code: int | None = None + content_type: str | None = None + headers: dict[str, str] = field(default_factory=dict) + + +@dataclass(frozen=True) +class GeometryGenerationError(ClientError): + """Geometry generation failure returned by the server.""" + + +@dataclass(frozen=True) +class MultiObjectGenerationServerRequest: + """Request sent to the Geometry Generation server (multi-object). + + Args: + image_path: Local scene RGB image path. + mask_paths: Local mask PNG file paths (one per object). + """ + + image_path: str | Path + mask_paths: list[Path] + + def to_form_data(self) -> dict[str, str]: + """Convert the request to the geometry server multipart form fields.""" + return {"json": "1"} + + +@dataclass(frozen=True) +class MultiObjectGenerationObject: + """Successful Multi-Object Geometry Generation result.""" + + name: str + geometry_path: str + rotation_quaternion_wxyz: list[float] + translation: list[float] + scale: list[float] + + +@dataclass(frozen=True) +class MultiObjectGenerationResult: + """Successful Multi-Object Geometry Generation result.""" + + objects: list[MultiObjectGenerationObject] + + @property + def geometry_paths(self) -> list[str]: + """Paths to the generated GLB files.""" + return [item.geometry_path for item in self.objects] + + +@dataclass(frozen=True) +class MultiObjectGenerationServerResponse: + """Parsed successful response from the Geometry Generation server.""" + + ok: bool + result: MultiObjectGenerationResult + status: str | None = None + error: str | None = None + status_code: int | None = None + content_type: str | None = None + headers: dict[str, str] = field(default_factory=dict) + + +@dataclass(frozen=True) +class MultiObjectGenerationError(ClientError): + """Multi-object geometry generation failure returned by the server.""" diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/__init__.py new file mode 100644 index 000000000..c112bd3d7 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/__init__.py @@ -0,0 +1,39 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_generation_client.client import ( + ImageGenerationClient, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_generation_client.schemas import ( + ImageGenerationError, + ImageGenerationResult, + ImageGenerationServerRequest, + ImageGenerationServerResponse, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.config import ( + DEFAULT_CLIENT_CONFIG_PATH, +) + +__all__ = [ + "DEFAULT_CLIENT_CONFIG_PATH", + "ImageGenerationClient", + "ImageGenerationError", + "ImageGenerationResult", + "ImageGenerationServerRequest", + "ImageGenerationServerResponse", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/client.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/client.py new file mode 100644 index 000000000..6f23d47bd --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/client.py @@ -0,0 +1,117 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Client for the Z-Image image generation server.""" + +from __future__ import annotations + +from pathlib import Path + +import requests + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.base import BaseHttpClient +from embodichain.gen_sim.prompt2scene.agent_tools.clients.common import ( + validate_required_strings, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.config import ( + DEFAULT_CLIENT_CONFIG_PATH, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_generation_client.parser import ( + parse_generation_response, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_generation_client.schemas import ( + ImageGenerationError, + ImageGenerationResult, + ImageGenerationServerRequest, + ImageGenerationServerResponse, +) + +__all__ = [ + "DEFAULT_CLIENT_CONFIG_PATH", + "ImageGenerationClient", + "ImageGenerationError", + "ImageGenerationResult", + "ImageGenerationServerRequest", + "ImageGenerationServerResponse", +] + + +class ImageGenerationClient(BaseHttpClient): + """Client for making single-image Z-Image generation requests.""" + + def __init__( + self, + *, + base_url: str | None = None, + timeout_s: int | None = None, + config_path: Path | None = None, + config_key: str = "zimage", + session: requests.Session | None = None, + ) -> None: + """Initialize the image generation client.""" + super().__init__( + config_key=config_key, + server_name="Image generation server", + base_url=base_url, + timeout_s=timeout_s, + config_path=config_path, + session=session, + ) + self.generate_single_object_path = str( + self.config.get("generate_single_object_path", "/generate.png") + ) + + def generate( + self, + request: ImageGenerationServerRequest, + *, + max_retries: int = 3, + ) -> ImageGenerationServerResponse | ImageGenerationError: + """Generate one image and save the returned PNG locally.""" + _validate_request(request) + url = f"{self.base_url}{self.generate_single_object_path}" + response = self.post_with_retries( + lambda: _post_generation_request(self, url, request), + max_retries=max_retries, + error_cls=ImageGenerationError, + request_label="image_generation", + ) + if isinstance(response, ImageGenerationError): + return response + return parse_generation_response(response, request) + + +def _validate_request(request: ImageGenerationServerRequest) -> None: + validate_required_strings( + { + "Image generation prompt": request.prompt, + "Image generation output_path": request.output_path, + } + ) + if not str(request.output_path).lower().endswith(".png"): + raise ValueError("Image generation output_path must be a PNG file path.") + + +def _post_generation_request( + client: ImageGenerationClient, + url: str, + request: ImageGenerationServerRequest, +) -> requests.Response: + return client.session.post( + url, + json=request.to_dict(), + timeout=(10, client.timeout_s), + ) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/parser.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/parser.py new file mode 100644 index 000000000..a43ee0307 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/parser.py @@ -0,0 +1,65 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path + +import requests + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.common import ( + validate_png_response, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_generation_client.schemas import ( + ImageGenerationResult, + ImageGenerationServerRequest, + ImageGenerationServerResponse, +) +from embodichain.gen_sim.prompt2scene.utils.log import log_info + +__all__ = ["parse_generation_response"] + + +def parse_generation_response( + response: requests.Response, + request: ImageGenerationServerRequest, +) -> ImageGenerationServerResponse: + """Parse a Z-Image PNG response and save it to the request output path.""" + png_bytes = response.content + validate_png_response(response, png_bytes) + output_path = _write_png_output(request, png_bytes) + result = ImageGenerationResult(image_path=str(output_path)) + return ImageGenerationServerResponse( + ok=True, + status="ok", + result=result, + status_code=response.status_code, + content_type=response.headers.get("Content-Type"), + headers=dict(response.headers), + ) + + +def _write_png_output( + request: ImageGenerationServerRequest, + png_bytes: bytes, +) -> Path: + output_path = Path(request.output_path).expanduser().resolve() + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_bytes(png_bytes) + if not output_path.is_file(): + raise FileNotFoundError(f"Generated image was not written: {output_path}") + log_info(f"Generated image written: {output_path}") + return output_path diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/schemas.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/schemas.py new file mode 100644 index 000000000..09c845bac --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_generation_client/schemas.py @@ -0,0 +1,72 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.common import ClientError + +__all__ = [ + "ImageGenerationError", + "ImageGenerationResult", + "ImageGenerationServerRequest", + "ImageGenerationServerResponse", +] + + +@dataclass(frozen=True) +class ImageGenerationServerRequest: + """Request sent to the Z-Image server. + + Args: + prompt: Text prompt used to generate the image. + output_path: Local output PNG path where the client saves the response. + """ + + prompt: str + output_path: str | Path + + def to_dict(self) -> dict[str, Any]: + """Convert the request to the Z-Image server JSON payload.""" + return {"prompt": self.prompt} + + +@dataclass(frozen=True) +class ImageGenerationResult: + """Successful Z-Image generation result.""" + + image_path: str + + +@dataclass(frozen=True) +class ImageGenerationServerResponse: + """Parsed successful response from the Z-Image server.""" + + ok: bool + result: ImageGenerationResult + status: str | None = None + error: str | None = None + status_code: int | None = None + content_type: str | None = None + headers: dict[str, str] = field(default_factory=dict) + + +@dataclass(frozen=True) +class ImageGenerationError(ClientError): + """Image generation failure returned by the server.""" diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/__init__.py new file mode 100644 index 000000000..a503f2875 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/__init__.py @@ -0,0 +1,61 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.config import ( + DEFAULT_CLIENT_CONFIG_PATH, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client.client import ( + ImageSegmentationClient, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client.schemas import ( + ImageSegmentationCandidate, + ImageSegmentationError, + ImageSegmentationResult, + ImageSegmentationServerRequest, + ImageSegmentationServerResponse, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client.utils import ( + apply_mask_to_alpha, + bbox_iou, + decode_rle_mask, + draw_labeled_bboxes, + draw_numbered_bboxes, + draw_numbered_masks, + is_usable_segmentation_candidate, + save_candidate_rgba_and_mask, + sort_segments_by_bbox, +) + +__all__ = [ + "DEFAULT_CLIENT_CONFIG_PATH", + "ImageSegmentationCandidate", + "ImageSegmentationClient", + "ImageSegmentationError", + "ImageSegmentationResult", + "ImageSegmentationServerRequest", + "ImageSegmentationServerResponse", + "apply_mask_to_alpha", + "bbox_iou", + "decode_rle_mask", + "draw_labeled_bboxes", + "draw_numbered_bboxes", + "draw_numbered_masks", + "is_usable_segmentation_candidate", + "save_candidate_rgba_and_mask", + "sort_segments_by_bbox", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/client.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/client.py new file mode 100644 index 000000000..1a880bb62 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/client.py @@ -0,0 +1,132 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Client for the SAM3 image segmentation server.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import requests + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.base import BaseHttpClient +from embodichain.gen_sim.prompt2scene.agent_tools.clients.common import ( + validate_required_strings, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.config import ( + DEFAULT_CLIENT_CONFIG_PATH, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client.parser import ( + parse_segmentation_response, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client.schemas import ( + ImageSegmentationCandidate, + ImageSegmentationError, + ImageSegmentationResult, + ImageSegmentationServerRequest, + ImageSegmentationServerResponse, +) + +__all__ = [ + "DEFAULT_CLIENT_CONFIG_PATH", + "ImageSegmentationCandidate", + "ImageSegmentationClient", + "ImageSegmentationError", + "ImageSegmentationResult", + "ImageSegmentationServerRequest", + "ImageSegmentationServerResponse", +] + + +class ImageSegmentationClient(BaseHttpClient): + """Client for making single-image SAM3 segmentation requests.""" + + def __init__( + self, + *, + base_url: str | None = None, + timeout_s: int | None = None, + config_path: Path | None = None, + config_key: str = "sam3_segmentation", + session: requests.Session | None = None, + ) -> None: + """Initialize the image segmentation client.""" + super().__init__( + config_key=config_key, + server_name="Image segmentation server", + base_url=base_url, + timeout_s=timeout_s, + config_path=config_path, + session=session, + trust_env=False, + ) + self.segmentation_path = str( + self.config.get("segment_single_object_path", "/segment_single_object") + ) + + def segment( + self, + request: ImageSegmentationServerRequest, + *, + max_retries: int = 3, + ) -> ImageSegmentationServerResponse | ImageSegmentationError: + """Segment one image with a text prompt.""" + _validate_request(request) + url = f"{self.base_url}{self.segmentation_path}" + response = self.post_with_retries( + lambda: _post_segmentation_request(self, url, request), + max_retries=max_retries, + error_cls=ImageSegmentationError, + request_label="image_segmentation", + ) + if isinstance(response, ImageSegmentationError): + return response + return parse_segmentation_response(response, request) + + +def _validate_request(request: ImageSegmentationServerRequest) -> None: + validate_required_strings( + { + "Image segmentation image_path": request.image_path, + } + ) + image_path = Path(request.image_path).expanduser() + if not image_path.is_file(): + raise FileNotFoundError(f"Image segmentation input not found: {image_path}") + + +def _post_segmentation_request( + client: ImageSegmentationClient, + url: str, + request: ImageSegmentationServerRequest, +) -> requests.Response: + with _open_image_file(request.image_path) as image_file: + return client.session.post( + url, + data=request.to_form_data(), + files={ + "image": ( + Path(request.image_path).name, + image_file, + ) + }, + timeout=(10, client.timeout_s), + ) + + +def _open_image_file(image_path: str | Path) -> Any: + return Path(image_path).expanduser().resolve().open("rb") diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/parser.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/parser.py new file mode 100644 index 000000000..762a1b43c --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/parser.py @@ -0,0 +1,218 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +import requests + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.common import ( + parse_json_object_response, +) +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client.schemas import ( + ImageSegmentationCandidate, + ImageSegmentationResult, + ImageSegmentationServerRequest, + ImageSegmentationServerResponse, +) + +__all__ = ["parse_segmentation_response"] + +SERVER_NAME = "Image segmentation server" + + +def parse_segmentation_response( + response: requests.Response, + request: ImageSegmentationServerRequest, +) -> ImageSegmentationServerResponse: + """Parse a SAM3 server JSON response into typed segmentation records.""" + response_data = parse_json_object_response( + response, + server_name=SERVER_NAME, + ) + result = _parse_segmentation_result(response_data, request) + return ImageSegmentationServerResponse( + ok=bool(response_data.get("ok", True)), + status=_string_or_none(response_data.get("status")) or "ok", + result=result, + status_code=response.status_code, + content_type=response.headers.get("Content-Type"), + headers=dict(response.headers), + ) + + +def _parse_segmentation_result( + response_data: dict[str, Any], + request: ImageSegmentationServerRequest, +) -> ImageSegmentationResult: + result_data = response_data.get("result") + if not isinstance(result_data, dict): + result_data = response_data.get("data") + if not isinstance(result_data, dict): + result_data = response_data + + return ImageSegmentationResult( + image_path=_string_or_none(result_data.get("image_path")) + or str(request.image_path), + prompt=_string_or_none(result_data.get("prompt")) or request.prompt, + candidates=_parse_candidates(result_data), + request_id=_string_or_none(result_data.get("request_id")), + elapsed_sec=_float_or_none(result_data.get("elapsed_sec")), + count=_int_or_none(result_data.get("count")), + image_width=_parse_image_width(result_data), + image_height=_parse_image_height(result_data), + box_format=_string_or_none(result_data.get("box_format")) or "xyxy", + mask_format=_string_or_none(result_data.get("mask_format")) or "rle", + ) + + +def _parse_candidates(result_data: dict[str, Any]) -> list[ImageSegmentationCandidate]: + for key in ("instances", "candidates", "segmentations", "detections"): + items = result_data.get(key) + if isinstance(items, list): + return [ + _parse_candidate_item(item, index) + for index, item in enumerate(items) + if isinstance(item, dict) + ] + + boxes = result_data.get("boxes", []) + scores = result_data.get("scores", []) + masks = result_data.get("masks", []) + if not isinstance(boxes, list): + return [] + + candidates: list[ImageSegmentationCandidate] = [] + for index, box in enumerate(boxes): + candidates.append( + ImageSegmentationCandidate( + candidate_id=f"candidate_{index}", + bbox_xyxy=_float_list(box), + score=_float_or_zero(_list_get(scores, index)), + mask_rle=_mask_or_none(_list_get(masks, index)), + ) + ) + return candidates + + +def _parse_candidate_item( + item: dict[str, Any], + index: int, +) -> ImageSegmentationCandidate: + known_keys = { + "candidate_id", + "id", + "index", + "bbox_xyxy", + "box_xyxy", + "box", + "bbox", + "score", + "mask_rle", + "mask", + "segmentation", + "mask_path", + "label", + } + mask_value = item.get("mask_rle") or item.get("mask") or item.get("segmentation") + return ImageSegmentationCandidate( + candidate_id=_string_or_none(item.get("candidate_id")) + or _string_or_none(item.get("id")) + or _index_id_or_none(item.get("index")) + or f"candidate_{index}", + bbox_xyxy=_float_list( + item.get("bbox_xyxy") + or item.get("box_xyxy") + or item.get("box") + or item.get("bbox") + ), + score=_float_or_zero(item.get("score")), + mask_rle=_mask_or_none(mask_value), + mask_path=_string_or_none(item.get("mask_path")), + label=_string_or_none(item.get("label")), + metadata={k: v for k, v in item.items() if k not in known_keys}, + ) + + +def _list_get(values: Any, index: int) -> Any: + if not isinstance(values, list) or index >= len(values): + return None + return values[index] + + +def _float_list(value: Any) -> list[float]: + if not isinstance(value, list): + return [] + parsed: list[float] = [] + for item in value: + try: + parsed.append(float(item)) + except (TypeError, ValueError): + continue + return parsed + + +def _float_or_zero(value: Any) -> float: + try: + return float(value) + except (TypeError, ValueError): + return 0.0 + + +def _float_or_none(value: Any) -> float | None: + try: + return float(value) + except (TypeError, ValueError): + return None + + +def _int_or_none(value: Any) -> int | None: + try: + return int(value) + except (TypeError, ValueError): + return None + + +def _string_or_none(value: Any) -> str | None: + return value if isinstance(value, str) else None + + +def _mask_or_none(value: Any) -> dict[str, Any] | None: + return value if isinstance(value, dict) else None + + +def _index_id_or_none(value: Any) -> str | None: + index = _int_or_none(value) + return f"candidate_{index}" if index is not None else None + + +def _parse_image_width(result_data: dict[str, Any]) -> int | None: + image_size = result_data.get("image_size") + if isinstance(image_size, dict): + width = _int_or_none(image_size.get("width")) + if width is not None: + return width + return _int_or_none(result_data.get("image_width")) + + +def _parse_image_height(result_data: dict[str, Any]) -> int | None: + image_size = result_data.get("image_size") + if isinstance(image_size, dict): + height = _int_or_none(image_size.get("height")) + if height is not None: + return height + return _int_or_none(result_data.get("image_height")) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/schemas.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/schemas.py new file mode 100644 index 000000000..3945bf4bd --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/schemas.py @@ -0,0 +1,103 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.common import ClientError + +__all__ = [ + "ImageSegmentationCandidate", + "ImageSegmentationError", + "ImageSegmentationResult", + "ImageSegmentationServerRequest", + "ImageSegmentationServerResponse", +] + + +@dataclass(frozen=True) +class ImageSegmentationServerRequest: + """Request sent to the SAM3 server. + + Args: + prompt: Short text concept prompt. + image_path: Local input image path. + """ + + prompt: str + image_path: str | Path + + def to_form_data(self) -> dict[str, str]: + """Convert the request to the SAM3 server multipart form fields.""" + return { + "prompt": self.prompt, + "score_threshold": "0.0", + "max_instances": "5", + } + + +@dataclass(frozen=True) +class ImageSegmentationCandidate: + """One SAM3 segmentation candidate for a prompted concept. + + SAM3 image inference returns parallel masks, boxes, and scores. The client + normalizes one aligned mask/box/score item into this candidate record. + """ + + candidate_id: str + bbox_xyxy: list[float] + score: float + mask_rle: dict[str, Any] | None = None + mask_path: str | None = None + label: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class ImageSegmentationResult: + """Successful SAM3 segmentation result.""" + + image_path: str + prompt: str + candidates: list[ImageSegmentationCandidate] + request_id: str | None = None + elapsed_sec: float | None = None + count: int | None = None + image_width: int | None = None + image_height: int | None = None + box_format: str = "xyxy" + mask_format: str | None = None + + +@dataclass(frozen=True) +class ImageSegmentationServerResponse: + """Parsed successful response from the SAM3 server.""" + + ok: bool + result: ImageSegmentationResult + status: str | None = None + error: str | None = None + status_code: int | None = None + content_type: str | None = None + headers: dict[str, str] = field(default_factory=dict) + + +@dataclass(frozen=True) +class ImageSegmentationError(ClientError): + """Image segmentation failure returned by the server.""" diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/utils.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/utils.py new file mode 100644 index 000000000..834573588 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/utils.py @@ -0,0 +1,322 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from PIL import Image, ImageDraw, ImageFont + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client.schemas import ( + ImageSegmentationCandidate, +) +from embodichain.gen_sim.prompt2scene.utils.log import log_info + +__all__ = [ + "apply_mask_to_alpha", + "bbox_iou", + "decode_rle_mask", + "draw_labeled_bboxes", + "draw_numbered_bboxes", + "draw_numbered_masks", + "is_usable_segmentation_candidate", + "save_candidate_rgba_and_mask", + "sort_segments_by_bbox", +] + + +def decode_rle_mask(mask_rle: dict[str, Any]) -> Image.Image: + """Decode an uncompressed SAM3 RLE mask into a grayscale PIL image.""" + size = mask_rle.get("size") + counts = mask_rle.get("counts") + if not _is_size_pair(size): + raise ValueError("SAM3 mask_rle requires size=[height, width].") + if not isinstance(counts, list): + raise ValueError("SAM3 mask_rle counts must be an uncompressed list.") + + height = int(size[0]) + width = int(size[1]) + expected_pixels = height * width + starts_with = int(mask_rle.get("starts_with", 0)) + value = 255 if starts_with else 0 + pixels = bytearray(expected_pixels) + offset = 0 + + for count_value in counts: + count = int(count_value) + if count < 0: + raise ValueError("SAM3 mask_rle counts must be non-negative.") + next_offset = offset + count + if next_offset > expected_pixels: + raise ValueError("SAM3 mask_rle counts exceed the expected image size.") + if value: + pixels[offset:next_offset] = b"\xff" * count + offset = next_offset + value = 0 if value else 255 + + if offset != expected_pixels: + raise ValueError( + "SAM3 mask_rle counts do not cover the expected image size: " + f"{offset} != {expected_pixels}." + ) + return Image.frombytes("L", (width, height), bytes(pixels)) + + +def apply_mask_to_alpha( + image_path: str | Path, + mask: Image.Image, +) -> Image.Image: + """Return an RGBA image whose alpha channel is the provided mask.""" + image = Image.open(image_path).convert("RGBA") + alpha = mask.convert("L") + if alpha.size != image.size: + alpha = alpha.resize(image.size, Image.Resampling.NEAREST) + image.putalpha(alpha) + return image + + +def save_candidate_rgba_and_mask( + *, + image_path: str | Path, + candidate: ImageSegmentationCandidate, + output_dir: str | Path, + prefix: str | None = None, +) -> dict[str, str]: + """Save one candidate's mask image and RGBA image for SAM3D input.""" + if candidate.mask_rle is None: + raise ValueError(f"Candidate {candidate.candidate_id} has no mask_rle.") + + output_dir = Path(output_dir).expanduser().resolve() + output_dir.mkdir(parents=True, exist_ok=True) + filename_prefix = prefix or candidate.candidate_id + mask_path = output_dir / f"{filename_prefix}_mask.png" + rgba_path = output_dir / f"{filename_prefix}_rgba.png" + + mask = decode_rle_mask(candidate.mask_rle) + mask.save(mask_path) + rgba = apply_mask_to_alpha(image_path, mask) + rgba.save(rgba_path) + log_info(f"SAM3 mask written: {mask_path}") + log_info(f"SAM3 RGBA image written: {rgba_path}") + return { + "mask_path": str(mask_path), + "rgba_path": str(rgba_path), + } + + +def draw_numbered_bboxes( + *, + image_path: str | Path, + segments: list[dict[str, Any]], + output_path: str | Path, +) -> Path: + """Draw numbered bounding boxes for visual segmentation verification.""" + image = Image.open(image_path).convert("RGB") + draw = ImageDraw.Draw(image) + font = _load_label_font(image.width) + for index, segment in enumerate(segments, start=1): + _draw_bbox_label( + draw=draw, + bbox_xyxy=segment["bbox_xyxy"], + label=str(index), + font=font, + ) + + output_path = Path(output_path).expanduser().resolve() + output_path.parent.mkdir(parents=True, exist_ok=True) + image.save(output_path) + return output_path + + +def draw_numbered_masks( + *, + image_path: str | Path, + segments: list[dict[str, Any]], + output_path: str | Path, +) -> Path: + """Draw numbered segmentation masks for visual segmentation verification.""" + image = Image.open(image_path).convert("RGBA") + overlay = Image.new("RGBA", image.size, (0, 0, 0, 0)) + draw_overlay = ImageDraw.Draw(overlay) + font = _load_label_font(image.width) + colors = [ + (255, 64, 64, 110), + (64, 160, 255, 110), + (64, 220, 120, 110), + (255, 190, 64, 110), + (190, 96, 255, 110), + (255, 96, 190, 110), + ] + + for index, segment in enumerate(segments, start=1): + mask_rle = segment.get("mask_rle") + if mask_rle is None: + continue + mask = decode_rle_mask(mask_rle) + if mask.size != image.size: + mask = mask.resize(image.size, Image.Resampling.NEAREST) + color = colors[(index - 1) % len(colors)] + color_layer = Image.new("RGBA", image.size, color) + transparent = Image.new("RGBA", image.size) + overlay.alpha_composite(Image.composite(color_layer, transparent, mask)) + _draw_mask_label( + draw=draw_overlay, + segment=segment, + mask=mask, + label=str(index), + font=font, + ) + + result = Image.alpha_composite(image, overlay).convert("RGB") + output_path = Path(output_path).expanduser().resolve() + output_path.parent.mkdir(parents=True, exist_ok=True) + result.save(output_path) + return output_path + + +def draw_labeled_bboxes( + *, + image_path: str | Path, + boxes: list[dict[str, Any]], + output_path: str | Path, +) -> Path: + """Draw labeled bounding boxes for final segmentation visualization.""" + image = Image.open(image_path).convert("RGB") + draw = ImageDraw.Draw(image) + font = _load_label_font(image.width) + for box in boxes: + x1, y1, x2, y2 = box["bbox_xyxy"] + label = str(box["label"]) + _draw_bbox_label( + draw=draw, + bbox_xyxy=[x1, y1, x2, y2], + label=label, + font=font, + ) + + output_path = Path(output_path).expanduser().resolve() + output_path.parent.mkdir(parents=True, exist_ok=True) + image.save(output_path) + return output_path + + +def sort_segments_by_bbox(segments: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Sort segments by top-left image position, then by descending score.""" + return sorted( + segments, + key=lambda segment: ( + float(segment["bbox_xyxy"][1]), + float(segment["bbox_xyxy"][0]), + -float(segment["score"]), + ), + ) + + +def bbox_iou(bbox_a: list[float], bbox_b: list[float]) -> float: + """Compute IoU for two xyxy bounding boxes.""" + ax1, ay1, ax2, ay2 = bbox_a + bx1, by1, bx2, by2 = bbox_b + ix1 = max(ax1, bx1) + iy1 = max(ay1, by1) + ix2 = min(ax2, bx2) + iy2 = min(ay2, by2) + iw = max(0.0, ix2 - ix1) + ih = max(0.0, iy2 - iy1) + intersection = iw * ih + area_a = max(0.0, ax2 - ax1) * max(0.0, ay2 - ay1) + area_b = max(0.0, bx2 - bx1) * max(0.0, by2 - by1) + union = area_a + area_b - intersection + return intersection / union if union > 0 else 0.0 + + +def is_usable_segmentation_candidate( + candidate: ImageSegmentationCandidate, +) -> bool: + """Return whether a candidate has the fields needed by downstream stages.""" + return candidate.mask_rle is not None and len(candidate.bbox_xyxy) == 4 + + +def _is_size_pair(value: Any) -> bool: + return ( + isinstance(value, list) + and len(value) == 2 + and isinstance(value[0], int) + and isinstance(value[1], int) + ) + + +def _load_label_font(image_width: int) -> ImageFont.ImageFont: + font_size = max(24, image_width // 80) + try: + return ImageFont.truetype("DejaVuSans-Bold.ttf", font_size) + except OSError: + return ImageFont.load_default() + + +def _draw_bbox_label( + *, + draw: ImageDraw.ImageDraw, + bbox_xyxy: list[float], + label: str, + font: ImageFont.ImageFont, +) -> None: + x1, y1, x2, y2 = bbox_xyxy + draw.rectangle((x1, y1, x2, y2), outline="red", width=6) + label_box = draw.textbbox((x1, y1), label, font=font) + padding = 8 + draw.rectangle( + ( + label_box[0] - padding, + label_box[1] - padding, + label_box[2] + padding, + label_box[3] + padding, + ), + fill="red", + ) + draw.text((x1, y1), label, fill="white", font=font) + + +def _draw_mask_label( + *, + draw: ImageDraw.ImageDraw, + segment: dict[str, Any], + mask: Image.Image, + label: str, + font: ImageFont.ImageFont, +) -> None: + bbox = mask.getbbox() + if bbox is None: + x1, y1, x2, y2 = segment["bbox_xyxy"] + x = float(x1 + x2) * 0.5 + y = float(y1 + y2) * 0.5 + else: + x1, y1, x2, y2 = bbox + x = float(x1 + x2) * 0.5 + y = float(y1 + y2) * 0.5 + + label_box = draw.textbbox((x, y), label, font=font) + padding = 8 + draw.rectangle( + ( + label_box[0] - padding, + label_box[1] - padding, + label_box[2] + padding, + label_box[3] + padding, + ), + fill="red", + ) + draw.text((x, y), label, fill="white", font=font) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/__init__.py new file mode 100644 index 000000000..32f8ef6cd --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/__init__.py @@ -0,0 +1,31 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.blender_rendering_manager.manager import ( + BlenderRenderingManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.blender_rendering_manager.schemas import ( + RenderObjectScenesRequest, + RenderObjectScenesResult, +) + +__all__ = [ + "BlenderRenderingManager", + "RenderObjectScenesRequest", + "RenderObjectScenesResult", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/manager.py new file mode 100644 index 000000000..8617f2975 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/manager.py @@ -0,0 +1,175 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +import subprocess +import tempfile +from pathlib import Path + +import numpy as np + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.blender_rendering_manager.schemas import ( + RenderObjectScenesRequest, + RenderObjectScenesResult, +) + +__all__ = ["BlenderRenderingManager"] + + +class BlenderRenderingManager: + """Render simulation scenes through Blender's background CLI.""" + + def render_object_scenes( + self, + request: RenderObjectScenesRequest, + ) -> RenderObjectScenesResult: + """Render a front-oblique view of a collection of Z-up scenes.""" + output_path = request.output_path.expanduser().resolve() + output_path.parent.mkdir(parents=True, exist_ok=True) + with tempfile.TemporaryDirectory(prefix="p2s_blender_render_") as tmp_dir: + glb_paths = self._export_y_up_scenes( + request.object_scenes, + Path(tmp_dir), + ) + self._render_glbs( + glb_paths, + output_path, + timeout_seconds=request.timeout_seconds, + ) + return RenderObjectScenesResult(output_path=output_path) + + @staticmethod + def _export_y_up_scenes( + object_scenes: list[tuple[str, object]], + output_dir: Path, + ) -> list[Path]: + z_up_to_y_up = np.array( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, -1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ], + dtype=np.float64, + ) + paths: list[Path] = [] + for object_id, scene in object_scenes: + path = output_dir / f"{object_id}_render.glb" + copied = scene.copy() + copied.apply_transform(z_up_to_y_up) + copied.export(path) + paths.append(path) + return paths + + @classmethod + def _render_glbs( + cls, + glb_paths: list[Path], + output_path: Path, + *, + timeout_seconds: int, + ) -> None: + script = cls._front_oblique_script(glb_paths, output_path) + with tempfile.NamedTemporaryFile( + mode="w", + suffix=".py", + encoding="utf-8", + delete=False, + ) as file: + script_path = Path(file.name) + file.write(script) + try: + subprocess.run( + ["blender", "--background", "--python", str(script_path)], + check=True, + timeout=timeout_seconds, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + except subprocess.CalledProcessError as exc: + stderr_tail = (exc.stderr or "").strip()[-4000:] + raise RuntimeError( + f"Blender front-oblique render failed:\n{stderr_tail}" + ) from exc + finally: + script_path.unlink(missing_ok=True) + if not output_path.is_file(): + raise FileNotFoundError(f"Blender render was not written: {output_path}") + + @staticmethod + def _front_oblique_script(glb_paths: list[Path], output_path: Path) -> str: + object_paths_json = json.dumps([str(path.resolve()) for path in glb_paths]) + output_path_json = json.dumps(str(output_path.resolve())) + return f"""\ +import bpy +import json +import mathutils + +object_paths = json.loads({object_paths_json!r}) +output_path = json.loads({output_path_json!r}) +bpy.ops.object.select_all(action="SELECT") +bpy.ops.object.delete() +for path in object_paths: + bpy.ops.import_scene.gltf(filepath=path) +mesh_objects = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"] +if not mesh_objects: + raise RuntimeError("No mesh objects were imported.") +min_corner = mathutils.Vector((float("inf"), float("inf"), float("inf"))) +max_corner = mathutils.Vector((float("-inf"), float("-inf"), float("-inf"))) +for obj in mesh_objects: + for corner in obj.bound_box: + world = obj.matrix_world @ mathutils.Vector(corner) + min_corner.x = min(min_corner.x, world.x) + min_corner.y = min(min_corner.y, world.y) + min_corner.z = min(min_corner.z, world.z) + max_corner.x = max(max_corner.x, world.x) + max_corner.y = max(max_corner.y, world.y) + max_corner.z = max(max_corner.z, world.z) +center = (min_corner + max_corner) * 0.5 +span_x = max(max_corner.x - min_corner.x, 1.0e-4) +span_y = max(max_corner.y - min_corner.y, 1.0e-4) +span_z = max(max_corner.z - min_corner.z, 1.0e-4) +camera_data = bpy.data.cameras.new("front_oblique_camera") +camera = bpy.data.objects.new("front_oblique_camera", camera_data) +bpy.context.collection.objects.link(camera) +view_distance = max(span_x, span_y, span_z) * 2.4 +camera.location = (center.x, center.y - view_distance, center.z + view_distance * 0.75) +camera.rotation_euler = (center - camera.location).to_track_quat("-Z", "Y").to_euler() +camera_data.type = "ORTHO" +camera_data.ortho_scale = max(span_x, span_y, span_z * 1.8) * 1.35 +bpy.context.scene.camera = camera +light_data = bpy.data.lights.new("front_oblique_area_light", "AREA") +light = bpy.data.objects.new("front_oblique_area_light", light_data) +bpy.context.collection.objects.link(light) +light.location = camera.location +light_data.energy = 600.0 +light_data.size = max(span_x, span_y) * 2.0 +bpy.context.scene.world.color = (1.0, 1.0, 1.0) +try: + bpy.context.scene.render.engine = "BLENDER_EEVEE_NEXT" +except Exception: + bpy.context.scene.render.engine = "BLENDER_EEVEE" +bpy.context.scene.render.resolution_x = 768 +bpy.context.scene.render.resolution_y = 768 +bpy.context.scene.render.film_transparent = False +bpy.context.scene.view_settings.view_transform = "Standard" +bpy.context.scene.view_settings.look = "Medium High Contrast" +bpy.context.scene.render.filepath = output_path +bpy.ops.render.render(write_still=True) +""" diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/schemas.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/schemas.py new file mode 100644 index 000000000..e3f986c7f --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/schemas.py @@ -0,0 +1,39 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +__all__ = ["RenderObjectScenesRequest", "RenderObjectScenesResult"] + + +@dataclass(frozen=True) +class RenderObjectScenesRequest: + """Request to render internal Z-up object scenes with Blender.""" + + object_scenes: list[tuple[str, Any]] + output_path: Path + timeout_seconds: int = 180 + + +@dataclass(frozen=True) +class RenderObjectScenesResult: + """Result of rendering object scenes.""" + + output_path: Path diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_generation_manager/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_generation_manager/__init__.py new file mode 100644 index 000000000..ef8b93154 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_generation_manager/__init__.py @@ -0,0 +1,45 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_generation_manager.manager import ( + GeometryGenerationManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_generation_manager.schemas import ( + GeometryGenerationRequest, + GeometryGenerationResult, + MultiObjectGenerationObject, + MultiObjectGenerationRequest, + MultiObjectGenerationResult, + RgbaImageToGeometryRequest, + RgbaImagesToGeometriesObject, + RgbaImagesToGeometriesRequest, + RgbaImagesToGeometriesResult, +) + +__all__ = [ + "GeometryGenerationManager", + "GeometryGenerationRequest", + "GeometryGenerationResult", + "MultiObjectGenerationObject", + "MultiObjectGenerationRequest", + "MultiObjectGenerationResult", + "RgbaImageToGeometryRequest", + "RgbaImagesToGeometriesObject", + "RgbaImagesToGeometriesRequest", + "RgbaImagesToGeometriesResult", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_generation_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_generation_manager/manager.py new file mode 100644 index 000000000..d30ea09aa --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_generation_manager/manager.py @@ -0,0 +1,209 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path + +from PIL import Image + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.geometry_generation_client import ( + GeometryGenerationClient, + GeometryGenerationError, + GeometryGenerationServerRequest, + MultiObjectGenerationError, + MultiObjectGenerationServerRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_generation_manager.schemas import ( + GeometryGenerationRequest, + GeometryGenerationResult, + MultiObjectGenerationObject, + MultiObjectGenerationRequest, + MultiObjectGenerationResult, + RgbaImageToGeometryRequest, + RgbaImagesToGeometriesObject, + RgbaImagesToGeometriesRequest, + RgbaImagesToGeometriesResult, +) + + +class GeometryGenerationManager: + """Geometry generation domain operations.""" + + def __init__(self, *, client: GeometryGenerationClient | None = None) -> None: + self.client = client or GeometryGenerationClient() + + def generate_single_object_mesh( + self, + request: GeometryGenerationRequest, + ) -> GeometryGenerationResult: + image_path = request.image_path.expanduser().resolve() + output_path = request.output_path.expanduser().resolve() + _validate_single_object_request(image_path=image_path, output_path=output_path) + + response = self.client.generate( + GeometryGenerationServerRequest( + image_path=image_path, + output_path=output_path, + ), + ) + if isinstance(response, GeometryGenerationError): + raise RuntimeError(response.error_message) + + return GeometryGenerationResult( + output_path=Path(response.result.geometry_path).expanduser().resolve(), + ) + + def generate_multi_object_meshes( + self, + request: MultiObjectGenerationRequest, + ) -> MultiObjectGenerationResult: + image_path = request.image_path.expanduser().resolve() + output_dir = request.output_dir.expanduser().resolve() + _validate_multi_object_request( + image_path=image_path, + mask_paths=request.mask_paths, + output_dir=output_dir, + ) + + response = self.client.generate_multiple_objects( + MultiObjectGenerationServerRequest( + image_path=image_path, + mask_paths=[p.expanduser().resolve() for p in request.mask_paths], + ), + output_dir=output_dir, + ) + if isinstance(response, MultiObjectGenerationError): + raise RuntimeError(response.error_message) + + objects = [ + MultiObjectGenerationObject( + name=item.name, + geometry_path=Path(item.geometry_path).expanduser().resolve(), + rotation_quaternion_wxyz=item.rotation_quaternion_wxyz, + translation=item.translation, + scale=item.scale, + ) + for item in response.result.objects + ] + return MultiObjectGenerationResult(objects=objects) + + def convert_rgba_image_to_geometry( + self, + request: RgbaImageToGeometryRequest, + ) -> Path: + image_path = request.image_path.expanduser().resolve() + output_path = request.output_path.expanduser().resolve() + _validate_rgba_image(image_path) + + result = self.generate_single_object_mesh( + GeometryGenerationRequest(image_path=image_path, output_path=output_path) + ) + return _postprocess_mesh(result.output_path) + + def convert_rgba_images_to_geometries( + self, + request: RgbaImagesToGeometriesRequest, + ) -> RgbaImagesToGeometriesResult: + image_path = request.image_path.expanduser().resolve() + output_dir = request.output_dir.expanduser().resolve() + _validate_rgba_images_request(image_path, request.mask_paths) + + result = self.generate_multi_object_meshes( + MultiObjectGenerationRequest( + image_path=image_path, + mask_paths=request.mask_paths, + output_dir=output_dir, + ) + ) + objects = [ + RgbaImagesToGeometriesObject( + name=item.name, + geometry_path=_postprocess_mesh(item.geometry_path), + rotation_quaternion_wxyz=item.rotation_quaternion_wxyz, + translation=item.translation, + scale=item.scale, + ) + for item in result.objects + ] + return RgbaImagesToGeometriesResult(objects=objects) + + +def _validate_single_object_request(*, image_path: Path, output_path: Path) -> None: + if not image_path.is_file(): + raise FileNotFoundError(f"Geometry generation input not found: {image_path}") + if output_path.suffix.lower() != ".glb": + raise ValueError("Geometry generation output_path must be a GLB file path.") + if output_path.exists() and output_path.is_dir(): + raise ValueError(f"Geometry generation output_path is a directory: {output_path}") + + +def _validate_multi_object_request( + *, + image_path: Path, + mask_paths: list[Path], + output_dir: Path, +) -> None: + if not image_path.is_file(): + raise FileNotFoundError( + f"Multi-object geometry generation input not found: {image_path}" + ) + if not mask_paths: + raise ValueError("mask_paths must be non-empty.") + for mask_path in mask_paths: + mask_path_resolved = mask_path.expanduser().resolve() + if not mask_path_resolved.is_file(): + raise FileNotFoundError( + f"Multi-object geometry mask not found: {mask_path_resolved}" + ) + if output_dir.exists() and not output_dir.is_dir(): + raise ValueError( + f"Multi-object geometry output_dir is not a directory: {output_dir}" + ) + + +def _validate_rgba_image(image_path: Path) -> None: + if not image_path.is_file(): + raise FileNotFoundError(f"RGBA image not found: {image_path}") + + with Image.open(image_path) as image: + if image.mode in {"RGBA", "LA"}: + return + if image.mode == "P" and "transparency" in image.info: + return + raise ValueError( + "Geometry tool requires an image with an alpha channel, " + f"got mode={image.mode!r}: {image_path}" + ) + + +def _validate_rgba_images_request( + image_path: Path, + mask_paths: list[Path], +) -> None: + if not image_path.is_file(): + raise FileNotFoundError(f"Scene image not found: {image_path}") + with Image.open(image_path): + pass + if not mask_paths: + raise ValueError("mask_paths must be non-empty.") + for mask_path in mask_paths: + if not mask_path.expanduser().resolve().is_file(): + raise FileNotFoundError(f"Mask not found: {mask_path}") + + +def _postprocess_mesh(mesh_path: Path) -> Path: + return mesh_path.expanduser().resolve() diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_generation_manager/schemas.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_generation_manager/schemas.py new file mode 100644 index 000000000..81f6816a8 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_generation_manager/schemas.py @@ -0,0 +1,105 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + + +@dataclass(frozen=True) +class RgbaImageToGeometryRequest: + """Request for converting one RGBA asset image to one mesh.""" + + image_path: Path + output_path: Path + + +@dataclass(frozen=True) +class RgbaImagesToGeometriesRequest: + """Request for converting a scene image with object masks to meshes.""" + + image_path: Path + mask_paths: list[Path] + output_dir: Path + + +@dataclass(frozen=True) +class RgbaImagesToGeometriesObject: + """One generated object mesh and its scene placement.""" + + name: str + geometry_path: Path + rotation_quaternion_wxyz: list[float] + translation: list[float] + scale: list[float] + + +@dataclass(frozen=True) +class RgbaImagesToGeometriesResult: + """Result of multi-object geometry generation.""" + + objects: list[RgbaImagesToGeometriesObject] + + @property + def geometry_paths(self) -> list[Path]: + return [item.geometry_path for item in self.objects] + + +@dataclass(frozen=True) +class GeometryGenerationRequest: + """Request for generating one object mesh from one image.""" + + image_path: Path + output_path: Path + + +@dataclass(frozen=True) +class GeometryGenerationResult: + """Generated mesh path.""" + + output_path: Path + + +@dataclass(frozen=True) +class MultiObjectGenerationRequest: + """Request to generate multiple object meshes from one image and masks.""" + + image_path: Path + mask_paths: list[Path] + output_dir: Path + + +@dataclass(frozen=True) +class MultiObjectGenerationObject: + """One generated object mesh and its scene placement.""" + + name: str + geometry_path: Path + rotation_quaternion_wxyz: list[float] + translation: list[float] + scale: list[float] + + +@dataclass(frozen=True) +class MultiObjectGenerationResult: + """Result of multi-object geometry generation.""" + + objects: list[MultiObjectGenerationObject] + + @property + def geometry_paths(self) -> list[Path]: + return [item.geometry_path for item in self.objects] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/__init__.py new file mode 100644 index 000000000..7d70c81c9 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/__init__.py @@ -0,0 +1,69 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.manager import ( + DEFAULT_INPUT_UP_AXIS, + DEFAULT_UP_AXIS, + GeometryManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.schemas import ( + AlignToAxisRequest, + AlignToAxisResult, + AlignXYLongAxisRequest, + AlignXYLongAxisResult, + CenterMeshRequest, + CenterMeshResult, + ConvertUpAxisRequest, + ConvertUpAxisResult, + DetectTabletopRequest, + DetectTabletopResult, + ExportMeshRequest, + ExportMeshResult, + LoadMeshRequest, + LoadMeshResult, + NormalizeRequest, + NormalizeResult, + PlaceAbovePlaneRequest, + PlaceAbovePlaneResult, + SupportPlaneCandidate, +) + +__all__ = [ + "AlignToAxisRequest", + "AlignToAxisResult", + "AlignXYLongAxisRequest", + "AlignXYLongAxisResult", + "CenterMeshRequest", + "CenterMeshResult", + "ConvertUpAxisRequest", + "ConvertUpAxisResult", + "DEFAULT_INPUT_UP_AXIS", + "DEFAULT_UP_AXIS", + "DetectTabletopRequest", + "DetectTabletopResult", + "ExportMeshRequest", + "ExportMeshResult", + "GeometryManager", + "LoadMeshRequest", + "LoadMeshResult", + "NormalizeRequest", + "NormalizeResult", + "PlaceAbovePlaneRequest", + "PlaceAbovePlaneResult", + "SupportPlaneCandidate", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/manager.py new file mode 100644 index 000000000..2e5c88ab3 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/manager.py @@ -0,0 +1,584 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Geometry manager for mesh I/O, transforms, and tabletop detection.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import numpy as np +import trimesh + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.schemas import ( + AlignToAxisRequest, + AlignToAxisResult, + AlignXYLongAxisRequest, + AlignXYLongAxisResult, + CenterMeshRequest, + CenterMeshResult, + ConvertUpAxisRequest, + ConvertUpAxisResult, + DetectTabletopRequest, + DetectTabletopResult, + ExportMeshRequest, + ExportMeshResult, + LoadMeshRequest, + LoadMeshResult, + NormalizeRequest, + NormalizeResult, + PlaceAbovePlaneRequest, + PlaceAbovePlaneResult, + SupportPlaneCandidate, +) + +__all__ = ["GeometryManager"] + +DEFAULT_INPUT_UP_AXIS = [0.0, 1.0, 0.0] +DEFAULT_UP_AXIS = [0.0, 0.0, 1.0] + + +class GeometryManager: + """Manager for mesh geometry operations. + + Provides typed methods for mesh I/O, axis conversion, bounding-box + transforms, tabletop plane detection, and PCA alignment, following + the same pattern as service clients. + """ + + + @staticmethod + def load_mesh(request: LoadMeshRequest) -> LoadMeshResult: + """Load a GLB/mesh file as one Trimesh object.""" + mesh_path = request.mesh_path.expanduser().resolve() + if not mesh_path.is_file(): + raise FileNotFoundError(f"Mesh file not found: {mesh_path}") + + loaded = trimesh.load(mesh_path, force=None) + if isinstance(loaded, trimesh.Scene): + geometries = [ + g + for g in loaded.dump(concatenate=False) + if hasattr(g, "vertices") and hasattr(g, "faces") + ] + if not geometries: + raise ValueError(f"Scene contains no mesh geometry: {mesh_path}") + return LoadMeshResult(mesh=trimesh.util.concatenate(geometries)) + return LoadMeshResult(mesh=loaded) + + @staticmethod + def export_mesh(request: ExportMeshRequest) -> ExportMeshResult: + """Export a mesh and return the resolved output path.""" + output_path = request.output_path.expanduser().resolve() + output_path.parent.mkdir(parents=True, exist_ok=True) + request.mesh.export(output_path) + if not output_path.is_file(): + raise FileNotFoundError(f"Mesh was not written: {output_path}") + return ExportMeshResult(output_path=output_path) + + + @staticmethod + def convert_up_axis(request: ConvertUpAxisRequest) -> ConvertUpAxisResult: + """Convert a mesh from one up-axis convention to another.""" + mesh = GeometryManager._align_vector_to_axis( + request.mesh, + source_axis=request.input_up_axis or DEFAULT_INPUT_UP_AXIS, + target_axis=request.output_up_axis or DEFAULT_UP_AXIS, + ) + return ConvertUpAxisResult(mesh=mesh) + + @staticmethod + def center_by_bbox(request: CenterMeshRequest) -> CenterMeshResult: + """Center a mesh by its bounding box.""" + GeometryManager._validate_mesh(request.mesh) + + bounds = np.asarray(request.mesh.bounds, dtype=float) + if bounds.shape != (2, 3): + raise ValueError("Mesh bounds must have shape (2, 3).") + + bbox_center = (bounds[0] + bounds[1]) * 0.5 + centered = request.mesh.copy() + centered.apply_translation(-bbox_center) + return CenterMeshResult( + mesh=centered, + bbox_center=[float(v) for v in bbox_center], + ) + + @staticmethod + def align_to_axis(request: AlignToAxisRequest) -> AlignToAxisResult: + """Rotate a mesh so a source vector aligns to a target axis.""" + mesh = GeometryManager._align_vector_to_axis( + request.mesh, + source_axis=request.source_axis, + target_axis=request.target_axis, + ) + return AlignToAxisResult(mesh=mesh) + + @staticmethod + def place_above_plane( + request: PlaceAbovePlaneRequest, + ) -> PlaceAbovePlaneResult: + """Translate a mesh so its AABB bottom is above the XY plane.""" + if request.clearance < 0.0: + raise ValueError("clearance must be non-negative.") + + bounds = np.asarray(request.mesh.bounds, dtype=float) + if bounds.shape != (2, 3): + raise ValueError("Mesh bounds must have shape (2, 3).") + + min_z = float(bounds[0][2]) + placed = request.mesh.copy() + placed.apply_translation([0.0, 0.0, request.clearance - min_z]) + return PlaceAbovePlaneResult(mesh=placed) + + @staticmethod + def normalize(request: NormalizeRequest) -> NormalizeResult: + """Scale a mesh so its longest bounding-box axis equals target_size.""" + if request.target_size <= 0.0: + raise ValueError("target_size must be positive.") + + extents = np.asarray( + request.mesh.bounding_box_oriented.primitive.extents, dtype=float + ) + scale_factor = request.target_size / float(np.max(extents)) + normalized = request.mesh.copy() + normalized.apply_scale(scale_factor) + return NormalizeResult(mesh=normalized, scale_factor=scale_factor) + + @staticmethod + def mesh_aabb_size(mesh: Any) -> Any: + """Return a mesh AABB size vector.""" + bounds = np.asarray(mesh.bounds, dtype=np.float64) + if bounds.shape != (2, 3): + raise ValueError("Mesh bounds must have shape (2, 3).") + size = bounds[1] - bounds[0] + if np.any(size <= 0.0): + raise ValueError(f"Mesh AABB size must be positive, got {size.tolist()}.") + return size + + @staticmethod + def bbox_ratio(size: Any) -> Any: + """Return bbox dimensions normalized by the largest axis.""" + size = np.asarray(size, dtype=np.float64) + max_size = float(np.max(size)) + if max_size <= 0.0: + raise ValueError("bbox size max must be positive.") + return size / max_size + + @staticmethod + def best_axis_bbox_scale_match( + *, + source_size_cm: Any, + target_size_cm: Any, + ) -> dict[str, Any]: + """Match target bbox axes to source axes and return a scale candidate.""" + source = np.asarray(source_size_cm, dtype=np.float64) + target = np.asarray(target_size_cm, dtype=np.float64) + if source.shape != (3,) or target.shape != (3,): + raise ValueError("source_size_cm and target_size_cm must have shape (3,).") + if np.any(source <= 0.0) or np.any(target <= 0.0): + raise ValueError("source_size_cm and target_size_cm must be positive.") + + source_ratio = GeometryManager.bbox_ratio(source) + best: dict[str, Any] | None = None + for permutation in [ + (0, 1, 2), + (0, 2, 1), + (1, 0, 2), + (1, 2, 0), + (2, 0, 1), + (2, 1, 0), + ]: + target_perm = target[list(permutation)] + target_ratio = GeometryManager.bbox_ratio(target_perm) + ratio_error = GeometryManager._mean_abs_log_ratio_error( + source_ratio, + target_ratio, + ) + per_axis_scale = target_perm / source + candidate = { + "target_permutation": list(permutation), + "source_size_cm": source.tolist(), + "target_size_cm_original_order": target.tolist(), + "target_size_cm_matched_to_source_axes": target_perm.tolist(), + "source_ratio": source_ratio.tolist(), + "target_ratio_matched": target_ratio.tolist(), + "per_axis_scale": per_axis_scale.tolist(), + "scale_factor": float(np.median(per_axis_scale)), + "shape_ratio_error": float(ratio_error), + } + if best is None or ratio_error < float(best["shape_ratio_error"]): + best = candidate + if best is None: + raise ValueError("Failed to match bbox axes.") + return best + + @staticmethod + def scene_to_mesh(scene: Any) -> Any: + """Convert a trimesh Scene or mesh-like object to one mesh.""" + if isinstance(scene, trimesh.Trimesh): + return scene + dumped = scene.dump(concatenate=True) + if isinstance(dumped, trimesh.Trimesh): + return dumped + meshes = [item for item in dumped if isinstance(item, trimesh.Trimesh)] + if not meshes: + raise ValueError("Scene contains no mesh geometry.") + return trimesh.util.concatenate(meshes) + + @staticmethod + def detect_tabletop( + request: DetectTabletopRequest, + ) -> DetectTabletopResult: + """Detect the most likely tabletop plane in a mesh.""" + candidates = GeometryManager._find_support_plane_candidates( + request.mesh, + normal_angle_tol_deg=request.normal_angle_tol_deg, + plane_distance_tol=request.plane_distance_tol, + min_area_ratio=request.min_area_ratio, + max_candidates=request.max_candidates, + ) + selected = GeometryManager._select_tabletop_plane(candidates) + oriented_normal = GeometryManager._orient_plane_normal( + request.mesh, + plane_normal=selected.normal, + plane_center=selected.center, + ) + return DetectTabletopResult( + selected=selected, + oriented_normal=oriented_normal, + candidates=candidates, + ) + + + @staticmethod + def align_xy_long_axis( + request: AlignXYLongAxisRequest, + ) -> AlignXYLongAxisResult: + """Rotate a table so its XY-projected long axis aligns with the Y axis.""" + vertices = np.asarray(request.mesh.vertices, dtype=float) + xy_vertices = GeometryManager._select_xy_vertices( + request.mesh, vertices, request.face_indices + ) + if xy_vertices.shape[0] < 2: + raise ValueError( + "Mesh must contain at least two vertices for PCA alignment." + ) + + centered_xy = xy_vertices - np.mean(xy_vertices, axis=0) + covariance = centered_xy.T @ centered_xy / max(centered_xy.shape[0] - 1, 1) + eigenvalues, eigenvectors = np.linalg.eigh(covariance) + long_axis = eigenvectors[:, int(np.argmax(eigenvalues))] + if float(np.linalg.norm(long_axis)) == 0.0: + raise ValueError("PCA long axis is degenerate.") + + axis_angle = float(np.arctan2(long_axis[1], long_axis[0])) + rotation_angle = GeometryManager._minimal_angle_to_align_axis( + axis_angle, np.pi / 2.0 + ) + rotation = GeometryManager._z_axis_rotation_transform(rotation_angle) + aligned = request.mesh.copy() + aligned.apply_transform(rotation) + return AlignXYLongAxisResult( + mesh=aligned, + yaw_angle_degrees=float(np.rad2deg(rotation_angle)), + ) + + + @staticmethod + def _align_vector_to_axis( + mesh: Any, + *, + source_axis: list[float], + target_axis: list[float], + ) -> Any: + source = GeometryManager._normalize( + np.asarray(source_axis, dtype=float) + ) + target = GeometryManager._normalize( + np.asarray(target_axis, dtype=float) + ) + if np.linalg.norm(source) == 0: + raise ValueError("source_axis must be non-zero.") + if np.linalg.norm(target) == 0: + raise ValueError("target_axis must be non-zero.") + + transform = GeometryManager._rotation_transform_between_vectors( + source, target + ) + aligned = mesh.copy() + aligned.apply_transform(transform) + return aligned + + + @staticmethod + def _find_support_plane_candidates( + mesh: Any, + *, + normal_angle_tol_deg: float = 8.0, + plane_distance_tol: float | None = None, + min_area_ratio: float = 0.02, + max_candidates: int = 24, + ) -> list[SupportPlaneCandidate]: + GeometryManager._validate_mesh(mesh) + + normals = np.asarray(mesh.face_normals, dtype=float) + centers = np.asarray(mesh.triangles_center, dtype=float) + areas = np.asarray(mesh.area_faces, dtype=float) + vertices = np.asarray(mesh.vertices, dtype=float) + total_area = float(np.sum(areas)) + if total_area <= 0: + raise ValueError("Mesh has no positive face area.") + + if plane_distance_tol is None: + extent = float( + np.linalg.norm(np.asarray(mesh.extents, dtype=float)) + ) + plane_distance_tol = max(extent * 0.01, 1e-4) + + cos_tol = float(np.cos(np.deg2rad(normal_angle_tol_deg))) + min_area = total_area * min_area_ratio + order = np.argsort(-areas) + used = np.zeros(len(areas), dtype=bool) + candidates: list[SupportPlaneCandidate] = [] + + for seed_index in order: + if used[seed_index]: + continue + seed_normal = GeometryManager._normalize(normals[seed_index]) + if np.linalg.norm(seed_normal) == 0: + used[seed_index] = True + continue + + seed_center = centers[seed_index] + seed_offset = float(np.dot(seed_normal, seed_center)) + normal_match = normals @ seed_normal >= cos_tol + offsets = centers @ seed_normal + plane_match = np.abs(offsets - seed_offset) <= plane_distance_tol + face_mask = normal_match & plane_match & ~used + face_indices = np.flatnonzero(face_mask) + if len(face_indices) == 0: + used[seed_index] = True + continue + + used[face_indices] = True + area = float(np.sum(areas[face_indices])) + if area < min_area: + continue + + weighted_normal = GeometryManager._normalize( + np.sum( + normals[face_indices] * areas[face_indices, None], axis=0 + ), + ) + center = ( + np.sum( + centers[face_indices] * areas[face_indices, None], axis=0 + ) + / area + ) + candidate = GeometryManager._build_candidate( + normal=weighted_normal, + center=center, + area=area, + face_indices=face_indices, + vertices=vertices, + ) + candidates.append(candidate) + + candidates.sort(key=lambda c: c.score, reverse=True) + return candidates[:max_candidates] + + @staticmethod + def _select_tabletop_plane( + candidates: list[SupportPlaneCandidate], + ) -> SupportPlaneCandidate: + if not candidates: + raise ValueError("No support-plane candidates were found.") + return max(candidates, key=lambda c: c.score) + + @staticmethod + def _orient_plane_normal( + mesh: Any, + *, + plane_normal: list[float], + plane_center: list[float], + ) -> list[float]: + GeometryManager._validate_mesh(mesh) + + normal = GeometryManager._normalize( + np.asarray(plane_normal, dtype=float) + ) + center = np.asarray(plane_center, dtype=float) + if np.linalg.norm(normal) == 0: + raise ValueError("plane_normal must be non-zero.") + + vertices = np.asarray(mesh.vertices, dtype=float) + signed_distances = (vertices - center) @ normal + positive_mask = signed_distances > 1e-6 + negative_mask = signed_distances < -1e-6 + positive_score = float(np.sum(np.abs(signed_distances[positive_mask]))) + negative_score = float(np.sum(np.abs(signed_distances[negative_mask]))) + + if positive_score > negative_score: + normal = -normal + return [float(v) for v in normal] + + @staticmethod + def _build_candidate( + *, + normal: Any, + center: Any, + area: float, + face_indices: Any, + vertices: Any, + ) -> SupportPlaneCandidate: + signed_distances = (vertices - center) @ normal + below_mask = signed_distances < -1e-6 + above_mask = signed_distances > 1e-6 + below_count = int(np.count_nonzero(below_mask)) + above_count = int(np.count_nonzero(above_mask)) + below_score = float(np.sum(np.abs(signed_distances[below_mask]))) + above_score = float(np.sum(np.abs(signed_distances[above_mask]))) + + smaller_score = min(below_score, above_score) + larger_score = max(below_score, above_score) + asymmetry_score = min( + (larger_score + 1e-9) / (smaller_score + 1e-9), 10.0 + ) + score = float(area * asymmetry_score) + return SupportPlaneCandidate( + normal=[float(v) for v in normal], + center=[float(v) for v in center], + area=area, + face_indices=[int(i) for i in face_indices], + below_vertex_count=below_count, + above_vertex_count=above_count, + below_area_score=below_score, + above_area_score=above_score, + score=score, + ) + + + @staticmethod + def _select_xy_vertices( + mesh: Any, + vertices: Any, + face_indices: list[int] | None, + ) -> Any: + if face_indices is None: + return vertices[:, :2] + + faces = np.asarray(mesh.faces, dtype=int) + selected_faces = faces[np.asarray(face_indices, dtype=int)] + selected_vertex_indices = np.unique(selected_faces.reshape(-1)) + return vertices[selected_vertex_indices, :2] + + @staticmethod + def _minimal_angle_to_align_axis( + source_angle: float, target_angle: float + ) -> float: + candidates = [ + GeometryManager._wrap_to_pi(target_angle - source_angle), + GeometryManager._wrap_to_pi( + target_angle + 3.141592653589793 - source_angle + ), + ] + return min(candidates, key=abs) + + @staticmethod + def _wrap_to_pi(angle: float) -> float: + two_pi = 2.0 * 3.141592653589793 + return (angle + 3.141592653589793) % two_pi - 3.141592653589793 + + @staticmethod + def _z_axis_rotation_transform(angle: float) -> Any: + c = float(np.cos(angle)) + s = float(np.sin(angle)) + transform = np.eye(4) + transform[:3, :3] = np.array( + [ + [c, -s, 0.0], + [s, c, 0.0], + [0.0, 0.0, 1.0], + ], + dtype=float, + ) + return transform + + + @staticmethod + def _rotation_transform_between_vectors( + source: Any, target: Any + ) -> Any: + dot = float(np.clip(np.dot(source, target), -1.0, 1.0)) + transform = np.eye(4) + if dot > 1.0 - 1e-8: + return transform + if dot < -1.0 + 1e-8: + axis = GeometryManager._orthogonal_axis(source) + rotation = GeometryManager._axis_angle_rotation(axis, np.pi) + else: + axis = GeometryManager._normalize(np.cross(source, target)) + angle = float(np.arccos(dot)) + rotation = GeometryManager._axis_angle_rotation(axis, angle) + transform[:3, :3] = rotation + return transform + + @staticmethod + def _axis_angle_rotation(axis: Any, angle: float) -> Any: + axis = GeometryManager._normalize(axis) + x, y, z = axis + c = float(np.cos(angle)) + s = float(np.sin(angle)) + one_c = 1.0 - c + return np.array( + [ + [c + x * x * one_c, x * y * one_c - z * s, x * z * one_c + y * s], + [y * x * one_c + z * s, c + y * y * one_c, y * z * one_c - x * s], + [z * x * one_c - y * s, z * y * one_c + x * s, c + z * z * one_c], + ], + dtype=float, + ) + + @staticmethod + def _orthogonal_axis(vector: Any) -> Any: + axis = np.array([1.0, 0.0, 0.0]) + if abs(float(np.dot(vector, axis))) > 0.9: + axis = np.array([0.0, 1.0, 0.0]) + return GeometryManager._normalize(np.cross(vector, axis)) + + @staticmethod + def _normalize(vector: Any) -> Any: + norm = float(np.linalg.norm(vector)) + if norm == 0.0: + return vector + return vector / norm + + @staticmethod + def _mean_abs_log_ratio_error(lhs: Any, rhs: Any) -> float: + eps = 1.0e-6 + lhs = np.maximum(np.asarray(lhs, dtype=np.float64), eps) + rhs = np.maximum(np.asarray(rhs, dtype=np.float64), eps) + return float(np.mean(np.abs(np.log(lhs / rhs)))) + + @staticmethod + def _validate_mesh(mesh: Any) -> None: + if not hasattr(mesh, "vertices") or not hasattr(mesh, "faces"): + raise ValueError("Loaded geometry is not a mesh.") + if len(mesh.vertices) == 0 or len(mesh.faces) == 0: + raise ValueError("Mesh must contain vertices and faces.") diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/scene_geometry.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/scene_geometry.py new file mode 100644 index 000000000..be502fbbe --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/scene_geometry.py @@ -0,0 +1,567 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + + +from __future__ import annotations + +from typing import Any + +import numpy as np + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager import ( + DetectTabletopRequest, + GeometryManager, +) + +__all__ = [ + "_compose_json_matrices", + "_compose_simready_to_aligned_matrix", + "_decompose_transform_matrix", + "_aabb_bottom_to_xy_plane_transform", + "_aabb_center", + "_compose_sam3d_multi_object_transform", + "_copy_scene_with_transform", + "_estimate_support_normal", + "_glb_to_sam3d_local_matrix", + "_load_scene_with_transform", + "_matrix_from_json", + "_quaternion_wxyz_to_matrix", + "_rotation_between_vectors", + "_row_linear_to_trimesh_matrix", + "_scale_transform", + "_scene_to_mesh", + "_support_normal_flip_transform", + "_transform_point", + "_validate_vector", + "_xy_aabb_center", + "_xy_aabb_size", + "_z_up_to_glb_y_up_transform", + "_z_yaw_transform", +] + + +def _compose_json_matrices(*values: Any) -> list[list[float]]: + matrices = [np.asarray(value, dtype=np.float64) for value in values] + if any(matrix.shape != (4, 4) for matrix in matrices): + return [] + result = np.eye(4, dtype=np.float64) + for matrix in matrices: + result = result @ matrix + return result.tolist() + + +def _compose_simready_to_aligned_matrix( + *, raw_to_aligned_matrix: Any, raw_to_simready_matrix: Any +) -> list[list[float]]: + raw_to_aligned = np.asarray(raw_to_aligned_matrix, dtype=np.float64) + raw_to_simready = np.asarray(raw_to_simready_matrix, dtype=np.float64) + if raw_to_aligned.shape != (4, 4) or raw_to_simready.shape != (4, 4): + return [] + try: + return (raw_to_aligned @ np.linalg.inv(raw_to_simready)).tolist() + except np.linalg.LinAlgError: + return [] + + +def _decompose_transform_matrix(matrix_value: Any) -> dict[str, Any]: + matrix = np.asarray(matrix_value, dtype=np.float64) + if matrix.shape != (4, 4): + return {"translation": [], "rotation_matrix": [], "scale": []} + linear = matrix[:3, :3] + scale = np.linalg.norm(linear, axis=0) + rotation = np.eye(3, dtype=np.float64) + for index in range(3): + if scale[index] > 1.0e-12: + rotation[:, index] = linear[:, index] / scale[index] + return { + "translation": matrix[:3, 3].tolist(), + "rotation_matrix": rotation.tolist(), + "scale": scale.tolist(), + } + + +def _support_normal_flip_transform( + *, + support_normal: np.ndarray, + normal_alignment: np.ndarray, +) -> np.ndarray: + flipped_normal_alignment = _rotation_between_vectors( + -support_normal, + np.array([0.0, 0.0, 1.0], dtype=np.float64), + ) + return flipped_normal_alignment @ np.linalg.inv(normal_alignment) + + +def _z_yaw_transform(yaw_degrees: float) -> np.ndarray: + angle = np.deg2rad(yaw_degrees) + c = float(np.cos(angle)) + s = float(np.sin(angle)) + transform = np.eye(4, dtype=np.float64) + transform[:3, :3] = np.array( + [ + [c, -s, 0.0], + [s, c, 0.0], + [0.0, 0.0, 1.0], + ], + dtype=np.float64, + ) + return transform + + +def _z_up_to_glb_y_up_transform() -> np.ndarray: + return _rotation_between_vectors( + np.array([0.0, 0.0, 1.0], dtype=np.float64), + np.array([0.0, 1.0, 0.0], dtype=np.float64), + ) + + +def _copy_scene_with_transform(scene: Any, transform: np.ndarray) -> Any: + copied = scene.copy() + copied.apply_transform(transform) + return copied + + +def _matrix_from_json(value: Any, *, name: str) -> np.ndarray: + matrix = np.asarray(value, dtype=np.float64) + if matrix.shape != (4, 4): + raise ValueError(f"{name} must be a 4x4 matrix.") + return matrix + + +def _load_scene_with_transform( + *, + path: Path, + transform: np.ndarray, + trimesh: Any, +) -> Any: + scene = trimesh.load(path, force="scene") + scene.apply_transform(transform) + return scene + + +def _scene_to_mesh(scene: Any, *, trimesh: Any) -> Any: + if isinstance(scene, trimesh.Trimesh): + return scene + dumped = scene.dump(concatenate=True) + if isinstance(dumped, trimesh.Trimesh): + return dumped + meshes = [item for item in dumped if isinstance(item, trimesh.Trimesh)] + if not meshes: + raise ValueError("Scene contains no mesh geometry.") + return trimesh.util.concatenate(meshes) + + +def _estimate_support_normal(mesh: Any) -> np.ndarray: + geom = GeometryManager() + try: + detect_result = geom.detect_tabletop(DetectTabletopRequest(mesh=mesh)) + normal = np.asarray(detect_result.oriented_normal, dtype=np.float64) + norm = np.linalg.norm(normal) + if norm > 0.0: + return normal / norm + except Exception: + pass + + normals = np.asarray(mesh.face_normals, dtype=np.float64) + areas = np.asarray(mesh.area_faces, dtype=np.float64) + if normals.size == 0 or areas.size == 0: + return np.array([0.0, 0.0, 1.0], dtype=np.float64) + normal = normals[int(np.argmax(areas))] + norm = np.linalg.norm(normal) + if norm == 0.0: + return np.array([0.0, 0.0, 1.0], dtype=np.float64) + return normal / norm + + +def _rotation_between_vectors(source: np.ndarray, target: np.ndarray) -> np.ndarray: + source = source / np.linalg.norm(source) + target = target / np.linalg.norm(target) + cross = np.cross(source, target) + dot = float(np.clip(np.dot(source, target), -1.0, 1.0)) + if np.linalg.norm(cross) < 1e-8: + if dot > 0.0: + return np.eye(4, dtype=np.float64) + axis = np.array([1.0, 0.0, 0.0], dtype=np.float64) + if abs(float(np.dot(source, axis))) > 0.9: + axis = np.array([0.0, 1.0, 0.0], dtype=np.float64) + cross = np.cross(source, axis) + axis = cross / np.linalg.norm(cross) + angle = float(np.arccos(dot)) + skew = np.array( + [ + [0.0, -axis[2], axis[1]], + [axis[2], 0.0, -axis[0]], + [-axis[1], axis[0], 0.0], + ], + dtype=np.float64, + ) + rotation = ( + np.eye(3, dtype=np.float64) + + np.sin(angle) * skew + + (1.0 - np.cos(angle)) * (skew @ skew) + ) + transform = np.eye(4, dtype=np.float64) + transform[:3, :3] = rotation + return transform + + +def _transform_point(transform: np.ndarray, point: np.ndarray) -> np.ndarray: + homogeneous = np.ones(4, dtype=np.float64) + homogeneous[:3] = point + return (transform @ homogeneous)[:3] + + +def _aabb_center(bounds: np.ndarray) -> np.ndarray: + return 0.5 * ( + np.asarray(bounds[0], dtype=np.float64) + + np.asarray(bounds[1], dtype=np.float64) + ) + + +def _xy_aabb_center(bounds: np.ndarray) -> np.ndarray: + bounds = np.asarray(bounds, dtype=np.float64) + return 0.5 * (bounds[0, :2] + bounds[1, :2]) + + +def _xy_aabb_size(bounds: np.ndarray) -> np.ndarray: + bounds = np.asarray(bounds, dtype=np.float64) + return np.maximum(bounds[1, :2] - bounds[0, :2], 1e-6) + + +def _aabb_bottom_to_xy_plane_transform(bounds: np.ndarray) -> np.ndarray: + bounds = np.asarray(bounds, dtype=np.float64) + min_z = float(bounds[0][2]) + transform = np.eye(4, dtype=np.float64) + transform[:3, 3] = [0.0, 0.0, -min_z] + return transform + + +def _scale_transform(scale: float) -> np.ndarray: + transform = np.eye(4, dtype=np.float64) + transform[:3, :3] *= float(scale) + return transform + + +def _compose_sam3d_multi_object_transform( + *, + rotation_quaternion_wxyz: list[float], + translation: list[float], + scale: list[float], +) -> np.ndarray: + """Compose the transform equivalent to the old baked multi-object export.""" + rotation = _quaternion_wxyz_to_matrix(rotation_quaternion_wxyz) + scale_matrix = np.diag(_validate_vector(scale, expected_len=3, name="scale")) + linear_row = _glb_to_sam3d_local_matrix() @ scale_matrix @ rotation + return _row_linear_to_trimesh_matrix( + linear_row=linear_row, + translation=translation, + ) + + +def _row_linear_to_trimesh_matrix( + *, + linear_row: np.ndarray, + translation: list[float], +) -> np.ndarray: + """Convert a row-vector linear transform to trimesh's 4x4 matrix format.""" + translation_vector = _validate_vector( + translation, + expected_len=3, + name="translation", + ) + transform = np.eye(4, dtype=np.float64) + transform[:3, :3] = linear_row.T + transform[:3, 3] = translation_vector + return transform + + +def _validate_vector( + values: list[float], + *, + expected_len: int, + name: str, +) -> np.ndarray: + """Validate and convert a numeric vector.""" + if len(values) != expected_len: + raise ValueError(f"{name} must have {expected_len} values") + return np.asarray(values, dtype=np.float64) + + +def _glb_to_sam3d_local_matrix() -> np.ndarray: + """Return the basis conversion used by the old baked multi-object exporter.""" + return np.array( + [ + [1.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, -1.0, 0.0], + ], + dtype=np.float64, + ) + + +def _quaternion_wxyz_to_matrix(quaternion: list[float]) -> np.ndarray: + """Convert a wxyz quaternion to a 3x3 rotation matrix.""" + if len(quaternion) != 4: + raise ValueError("rotation_quaternion_wxyz must have 4 values") + w, x, y, z = [float(v) for v in quaternion] + norm = np.sqrt(w * w + x * x + y * y + z * z) + if norm == 0.0: + raise ValueError("rotation quaternion must be non-zero") + w, x, y, z = w / norm, x / norm, y / norm, z / norm + return np.array( + [ + [ + 1.0 - 2.0 * (y * y + z * z), + 2.0 * (x * y - z * w), + 2.0 * (x * z + y * w), + ], + [ + 2.0 * (x * y + z * w), + 1.0 - 2.0 * (x * x + z * z), + 2.0 * (y * z - x * w), + ], + [ + 2.0 * (x * z - y * w), + 2.0 * (y * z + x * w), + 1.0 - 2.0 * (x * x + y * y), + ], + ], + dtype=np.float64, + ) + + +def _detect_table_fit_support_quad( + mesh: Any, + *, + target_aspect: float, +) -> dict[str, Any]: + geom = GeometryManager() + detect = geom.detect_tabletop(DetectTabletopRequest(mesh=mesh)) + faces = np.asarray(mesh.faces, dtype=np.int64) + vertices = np.asarray(mesh.vertices, dtype=np.float64) + support_vertices = vertices[ + np.unique(faces[np.asarray(detect.selected.face_indices, dtype=np.int64)]) + ] + hull_xy = _table_fit_convex_hull_2d(support_vertices[:, :2]) + quad = _largest_centered_table_fit_inscribed_rect( + hull_xy, + target_aspect=max(float(target_aspect), 1.0e-6), + ) + center_z = float(np.mean(support_vertices[:, 2])) + return { + "method": "sampled_centered_inscribed_rectangle_on_support_convex_hull", + "normal": detect.oriented_normal, + "area": float(detect.selected.area), + "center": [quad["center_xy"][0], quad["center_xy"][1], center_z], + "center_xy": quad["center_xy"], + "size_xy": quad["size_xy"], + "yaw_radians": quad["yaw_radians"], + "yaw_degrees": float(np.rad2deg(quad["yaw_radians"])), + "corners_xy": quad["corners_xy"], + "support_hull_xy": hull_xy.tolist(), + } + + +def _largest_centered_table_fit_inscribed_rect( + hull_xy: np.ndarray, + *, + target_aspect: float, + yaw_samples: int = 180, +) -> dict[str, Any]: + if hull_xy.shape[0] < 3: + raise ValueError("Support hull must contain at least 3 points.") + best: dict[str, Any] | None = None + centers = [ + np.mean(hull_xy, axis=0), + 0.5 * (np.min(hull_xy, axis=0) + np.max(hull_xy, axis=0)), + ] + for yaw in np.linspace(0.0, np.pi, yaw_samples, endpoint=False): + rot = _table_fit_rot2(-yaw) + inv_rot = _table_fit_rot2(yaw) + rotated_hull = hull_xy @ rot.T + for center_world in centers: + center = center_world @ rot.T + lo = 0.0 + bbox_size = np.max(rotated_hull, axis=0) - np.min(rotated_hull, axis=0) + hi = float(max(bbox_size[0] / target_aspect, bbox_size[1], 1.0e-6)) + for _ in range(40): + mid = 0.5 * (lo + hi) + width = target_aspect * mid + depth = mid + corners = _table_fit_rect_corners( + center=center, + width=width, + depth=depth, + ) + corners_world = corners @ inv_rot.T + if all( + _table_fit_point_in_convex_polygon(point, hull_xy) + for point in corners_world + ): + lo = mid + else: + hi = mid + width = target_aspect * lo + depth = lo + area = width * depth + corners_world = ( + _table_fit_rect_corners(center=center, width=width, depth=depth) + @ inv_rot.T + ) + candidate = { + "center_xy": center_world.tolist(), + "size_xy": [float(width), float(depth)], + "yaw_radians": float(yaw), + "corners_xy": corners_world.tolist(), + "area": float(area), + } + if best is None or area > float(best["area"]): + best = candidate + if best is None: + raise ValueError("Failed to estimate an inscribed support rectangle.") + return best + + +def _load_table_fit_scene_internal_z( + path: Path, + *, + trimesh: Any, + y_to_z: np.ndarray, +) -> Any: + if not path.is_file(): + raise FileNotFoundError(f"GLB not found: {path}") + scene = trimesh.load(path, force="scene") + scene.apply_transform(y_to_z) + return scene + + +def _table_fit_scene_union_bounds(scenes: list[Any], *, trimesh: Any) -> np.ndarray: + bounds = [ + np.asarray(_scene_to_mesh(scene, trimesh=trimesh).bounds, dtype=np.float64) + for scene in scenes + ] + return np.vstack( + [ + np.vstack([item[0] for item in bounds]).min(axis=0), + np.vstack([item[1] for item in bounds]).max(axis=0), + ] + ) + + +def _table_fit_bounds_xy_manifest( + bounds: np.ndarray, + *, + unit_scale: float, +) -> dict[str, Any]: + min_xy = bounds[0, :2] * unit_scale + max_xy = bounds[1, :2] * unit_scale + size_xy = max_xy - min_xy + center_xy = 0.5 * (min_xy + max_xy) + return { + "unit": "cm", + "min_xy": min_xy.tolist(), + "max_xy": max_xy.tolist(), + "center_xy": center_xy.tolist(), + "size_xy": size_xy.tolist(), + "area": float(size_xy[0] * size_xy[1]), + } + + +def _table_fit_uniform_xy_scale_transform( + *, + center_xy: np.ndarray, + scale: float, +) -> np.ndarray: + center = np.eye(4, dtype=np.float64) + center[:3, 3] = [float(center_xy[0]), float(center_xy[1]), 0.0] + uncenter = np.eye(4, dtype=np.float64) + uncenter[:3, 3] = [-float(center_xy[0]), -float(center_xy[1]), 0.0] + scale_mat = np.eye(4, dtype=np.float64) + scale_mat[0, 0] = float(scale) + scale_mat[1, 1] = float(scale) + return center @ scale_mat @ uncenter + + +def _table_fit_safe_positive_ratio(numerator: float, denominator: float) -> float: + return max(float(numerator) / max(float(denominator), 1.0e-6), 1.0e-6) + + +def _table_fit_convex_hull_2d(points: np.ndarray) -> np.ndarray: + unique = sorted({(float(x), float(y)) for x, y in np.asarray(points)[:, :2]}) + if len(unique) <= 1: + return np.asarray(unique, dtype=np.float64) + + def cross( + o: tuple[float, float], + a: tuple[float, float], + b: tuple[float, float], + ) -> float: + return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0]) + + lower: list[tuple[float, float]] = [] + for point in unique: + while len(lower) >= 2 and cross(lower[-2], lower[-1], point) <= 0.0: + lower.pop() + lower.append(point) + upper: list[tuple[float, float]] = [] + for point in reversed(unique): + while len(upper) >= 2 and cross(upper[-2], upper[-1], point) <= 0.0: + upper.pop() + upper.append(point) + return np.asarray(lower[:-1] + upper[:-1], dtype=np.float64) + + +def _table_fit_point_in_convex_polygon( + point: np.ndarray, + polygon: np.ndarray, +) -> bool: + previous = 0.0 + for index in range(len(polygon)): + a = polygon[index] + b = polygon[(index + 1) % len(polygon)] + cross = float(np.cross(b - a, point - a)) + if abs(cross) < 1.0e-9: + continue + if previous == 0.0: + previous = cross + elif cross * previous < -1.0e-9: + return False + return True + + +def _table_fit_rect_corners( + *, + center: np.ndarray, + width: float, + depth: float, +) -> np.ndarray: + half_w = 0.5 * float(width) + half_d = 0.5 * float(depth) + return np.asarray( + [ + [center[0] - half_w, center[1] - half_d], + [center[0] + half_w, center[1] - half_d], + [center[0] + half_w, center[1] + half_d], + [center[0] - half_w, center[1] + half_d], + ], + dtype=np.float64, + ) + + +def _table_fit_rot2(angle: float) -> np.ndarray: + c = float(np.cos(angle)) + s = float(np.sin(angle)) + return np.asarray([[c, -s], [s, c]], dtype=np.float64) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/schemas.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/schemas.py new file mode 100644 index 000000000..f001720fc --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/schemas.py @@ -0,0 +1,201 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +__all__ = [ + "AlignToAxisRequest", + "AlignToAxisResult", + "AlignXYLongAxisRequest", + "AlignXYLongAxisResult", + "CenterMeshRequest", + "NormalizeRequest", + "NormalizeResult", + "CenterMeshResult", + "ConvertUpAxisRequest", + "ConvertUpAxisResult", + "DetectTabletopRequest", + "DetectTabletopResult", + "ExportMeshRequest", + "ExportMeshResult", + "LoadMeshRequest", + "LoadMeshResult", + "PlaceAbovePlaneRequest", + "PlaceAbovePlaneResult", + "SupportPlaneCandidate", +] + + +@dataclass(frozen=True) +class SupportPlaneCandidate: + """Candidate planar tabletop support surface.""" + + normal: list[float] + center: list[float] + area: float + face_indices: list[int] + below_vertex_count: int + above_vertex_count: int + below_area_score: float + above_area_score: float + score: float + + +@dataclass(frozen=True) +class LoadMeshRequest: + """Request to load a GLB/mesh file.""" + + mesh_path: Path + + +@dataclass(frozen=True) +class LoadMeshResult: + """Result of loading a mesh file.""" + + mesh: Any + + +@dataclass(frozen=True) +class ExportMeshRequest: + """Request to export a mesh to a file.""" + + mesh: Any + output_path: Path + + +@dataclass(frozen=True) +class ExportMeshResult: + """Result of exporting a mesh.""" + + output_path: Path + + +@dataclass(frozen=True) +class ConvertUpAxisRequest: + """Request to convert a mesh from one up-axis convention to another.""" + + mesh: Any + input_up_axis: list[float] | None = None + output_up_axis: list[float] | None = None + + +@dataclass(frozen=True) +class ConvertUpAxisResult: + """Result of converting a mesh up-axis.""" + + mesh: Any + + +@dataclass(frozen=True) +class CenterMeshRequest: + """Request to center a mesh by its bounding-box center.""" + + mesh: Any + + +@dataclass(frozen=True) +class CenterMeshResult: + """Result of centering a mesh.""" + + mesh: Any + bbox_center: list[float] + + +@dataclass(frozen=True) +class AlignToAxisRequest: + """Request to rotate a mesh so a source axis aligns to a target axis.""" + + mesh: Any + source_axis: list[float] + target_axis: list[float] + + +@dataclass(frozen=True) +class AlignToAxisResult: + """Result of aligning a mesh vector to an axis.""" + + mesh: Any + + +@dataclass(frozen=True) +class PlaceAbovePlaneRequest: + """Request to translate a mesh so its AABB bottom sits above the XY plane.""" + + mesh: Any + clearance: float = 0.01 + + +@dataclass(frozen=True) +class PlaceAbovePlaneResult: + """Result of placing a mesh above the XY plane.""" + + mesh: Any + + +@dataclass(frozen=True) +class DetectTabletopRequest: + """Request to detect the most likely tabletop plane in a mesh.""" + + mesh: Any + normal_angle_tol_deg: float = 8.0 + plane_distance_tol: float | None = None + min_area_ratio: float = 0.02 + max_candidates: int = 24 + + +@dataclass(frozen=True) +class DetectTabletopResult: + """Result of detecting the tabletop plane with oriented normal.""" + + selected: SupportPlaneCandidate + oriented_normal: list[float] + candidates: list[SupportPlaneCandidate] + + +@dataclass(frozen=True) +class AlignXYLongAxisRequest: + """Request to align a mesh XY long axis to the Y axis via PCA.""" + + mesh: Any + face_indices: list[int] | None = None + + +@dataclass(frozen=True) +class AlignXYLongAxisResult: + """Result of PCA yaw alignment.""" + + mesh: Any + yaw_angle_degrees: float + + +@dataclass(frozen=True) +class NormalizeRequest: + """Request to normalize a mesh to a target size.""" + + mesh: Any + target_size: float = 1.0 + + +@dataclass(frozen=True) +class NormalizeResult: + """Result of normalizing a mesh.""" + + mesh: Any + scale_factor: float diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_generation_manager/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_generation_manager/__init__.py new file mode 100644 index 000000000..c7a200a51 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_generation_manager/__init__.py @@ -0,0 +1,35 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_generation_manager.manager import ( + ASSET_IMAGE_PROMPT_SUFFIX, + ImageGenerationManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_generation_manager.schemas import ( + ImageGenerationRequest, + ImageGenerationResult, + TextToAssetImageRequest, +) + +__all__ = [ + "ASSET_IMAGE_PROMPT_SUFFIX", + "ImageGenerationManager", + "ImageGenerationRequest", + "ImageGenerationResult", + "TextToAssetImageRequest", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_generation_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_generation_manager/manager.py new file mode 100644 index 000000000..6406f74d3 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_generation_manager/manager.py @@ -0,0 +1,76 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_generation_client import ( + ImageGenerationClient, + ImageGenerationError, + ImageGenerationServerRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_generation_manager.schemas import ( + ImageGenerationRequest, + ImageGenerationResult, + TextToAssetImageRequest, +) + +ASSET_IMAGE_PROMPT_SUFFIX = ( + "single isolated object, centered, fully visible, " + "on a high contrast colored background. " +) + + +class ImageGenerationManager: + """Image generation domain operations.""" + + def __init__(self, *, client: ImageGenerationClient | None = None) -> None: + self.client = client or ImageGenerationClient() + + def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResult: + output_path = request.output_path.expanduser().resolve() + response = self.client.generate( + ImageGenerationServerRequest( + prompt=request.prompt, + output_path=output_path, + ), + ) + if isinstance(response, ImageGenerationError): + raise RuntimeError(response.error_message) + + return ImageGenerationResult( + image_path=Path(response.result.image_path).expanduser().resolve(), + ) + + def generate_asset_image_from_text( + self, + request: TextToAssetImageRequest, + ) -> Path: + prompt = _build_asset_image_prompt(request.prompt) + result = self.generate_image( + ImageGenerationRequest(prompt=prompt, output_path=request.output_path) + ) + return result.image_path + + +def _build_asset_image_prompt(prompt: str) -> str: + prompt = prompt.strip() + if not prompt: + raise ValueError("Text-to-asset image prompt must be non-empty.") + if ASSET_IMAGE_PROMPT_SUFFIX in prompt: + return prompt + return f"{prompt}, {ASSET_IMAGE_PROMPT_SUFFIX}" diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_generation_manager/schemas.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_generation_manager/schemas.py new file mode 100644 index 000000000..ac4a9cd7e --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_generation_manager/schemas.py @@ -0,0 +1,43 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + + +@dataclass(frozen=True) +class TextToAssetImageRequest: + """Request for generating an asset image from a text prompt.""" + + prompt: str + output_path: Path + + +@dataclass(frozen=True) +class ImageGenerationRequest: + """Request for generating one image from text.""" + + prompt: str + output_path: Path + + +@dataclass(frozen=True) +class ImageGenerationResult: + """Generated image path.""" + + image_path: Path diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/__init__.py new file mode 100644 index 000000000..2ad8f11a5 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/__init__.py @@ -0,0 +1,29 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager.alignment import ( + _export_support_aligned_layout_glbs, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager.manifests import ( + _write_multi_object_layout_manifests, +) + +__all__ = [ + "_export_support_aligned_layout_glbs", + "_write_multi_object_layout_manifests", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/alignment.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/alignment.py new file mode 100644 index 000000000..6d7084f44 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/alignment.py @@ -0,0 +1,537 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + + +from __future__ import annotations + +import traceback +from pathlib import Path +from typing import Any + +import numpy as np + +from embodichain.gen_sim.prompt2scene.workflows.llm_output import ( + call_structured_json_model_step, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager.prompts import ( + build_up_down_flip_check_messages, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.metric_scale_manager import ( + GlobalMetricScaleRequest, + MetricScaleManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager.schemas import ( + UP_DOWN_FLIP_CHECK_JSON_SCHEMA, +) + +UP_DOWN_FLIP_CHECK_CONFIDENCE_THRESHOLD = 0.6 +UNIFIED_SCENE_STEP = "unified_scene" +from embodichain.gen_sim.prompt2scene.agent_tools.managers.blender_rendering_manager import ( + BlenderRenderingManager, + RenderObjectScenesRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.matplotlib_manager import ( + MatplotlibManager, + RenderImageComparisonRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.scene_geometry import ( + _aabb_center, + _copy_scene_with_transform, + _estimate_support_normal, + _load_scene_with_transform, + _matrix_from_json, + _rotation_between_vectors, + _scale_transform, + _scene_to_mesh, + _support_normal_flip_transform, + _xy_aabb_center, + _z_up_to_glb_y_up_transform, + _z_yaw_transform, +) +from embodichain.gen_sim.prompt2scene.utils.io import ( + relative_path, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.optimization_manager import ( + _object_scenes_xy_aabb_manifest, + _settle_and_pack_object_footprints, +) + +__all__ = ["_export_support_aligned_layout_glbs"] + + +def _export_support_aligned_layout_glbs( + *, + table: dict[str, Any], + objects: list[dict[str, Any]], + spatial_relations: list[dict[str, Any]], + original_image_path: Path | None, + llm: Any | None, + output_dir: Path, + output_root: Path, +) -> dict[str, Any]: + """Export layout-baked GLBs aligned by support normal and left-right order.""" + try: + import trimesh + except ImportError as exc: + raise RuntimeError("Support-aligned GLB export requires trimesh.") from exc + + output_dir.mkdir(parents=True, exist_ok=True) + support_reference_path = _resolve_generated_path( + table.get("support_reference_geometry_path") or table.get("raw_geometry_path"), + output_root, + ) + object_paths = [ + ( + str(item["id"]), + _resolve_generated_path(item.get("raw_geometry_path"), output_root), + item.get("transform_matrix"), + ) + for item in objects + if item.get("raw_geometry_path") and item.get("transform_matrix") + ] + if not support_reference_path.is_file(): + raise FileNotFoundError( + f"Support reference table GLB not found: {support_reference_path}" + ) + support_reference_transform = _matrix_from_json( + table.get("support_reference_transform_matrix") + or table.get("transform_matrix"), + name="table.support_reference_transform_matrix", + ) + if not object_paths: + raise ValueError("No raw object GLBs with transform matrices available.") + + support_reference_scene = trimesh.load(support_reference_path, force="scene") + support_reference_scene.apply_transform(support_reference_transform) + object_scenes = [ + ( + object_id, + _load_scene_with_transform( + path=path, + transform=_matrix_from_json( + transform, + name=f"{object_id}.transform_matrix", + ), + trimesh=trimesh, + ), + ) + for object_id, path, transform in object_paths + ] + table_mesh = _scene_to_mesh(support_reference_scene, trimesh=trimesh) + support_normal = _estimate_support_normal(table_mesh) + normal_alignment = _rotation_between_vectors( + support_normal, + np.array([0.0, 0.0, 1.0]), + ) + + for _, scene in object_scenes: + scene.apply_transform(normal_alignment) + + object_bounds = [ + _scene_to_mesh(scene, trimesh=trimesh).bounds for _, scene in object_scenes + ] + clutter_bounds = np.vstack( + [ + np.vstack([bounds[0] for bounds in object_bounds]).min(axis=0), + np.vstack([bounds[1] for bounds in object_bounds]).max(axis=0), + ] + ) + clutter_center = 0.5 * (clutter_bounds[0] + clutter_bounds[1]) + center_transform = np.eye(4, dtype=np.float64) + center_transform[:3, 3] = [ + -float(clutter_center[0]), + -float(clutter_center[1]), + -float(clutter_center[2]), + ] + + for _, scene in object_scenes: + scene.apply_transform(center_transform) + + alignment_candidates = _build_up_down_alignment_candidates( + object_scenes=object_scenes, + support_normal=support_normal, + normal_alignment=normal_alignment, + spatial_relations=spatial_relations, + trimesh=trimesh, + ) + vlm_check_dir = output_dir / "vlm_up_down_flip_check" + up_down_flip_check_result = _run_aligned_up_down_flip_vlm_check( + llm=llm, + original_image_path=original_image_path, + normal_object_scenes=alignment_candidates["normal"]["object_scenes"], + flipped_object_scenes=alignment_candidates["flipped"]["object_scenes"], + output_dir=vlm_check_dir, + ) + selected_variant = str( + up_down_flip_check_result.get("selected_variant") or "normal" + ) + if selected_variant not in alignment_candidates: + selected_variant = "normal" + selected_candidate = alignment_candidates[selected_variant] + object_scenes = selected_candidate["object_scenes"] + selected_extra_transform = selected_candidate["extra_transform"] + apply_up_down_flip = selected_variant == "flipped" + + global_metric_scale = MetricScaleManager.compute_global_from_object_scenes( + GlobalMetricScaleRequest( + objects=objects, + object_scenes=object_scenes, + ) + ) + metric_scale_transform = _scale_transform(global_metric_scale["scale_factor"]) + if float(global_metric_scale["scale_factor"]) != 1.0: + for _, scene in object_scenes: + scene.apply_transform(metric_scale_transform) + + footprint_result = _settle_and_pack_object_footprints( + object_scenes=object_scenes, + output_dir=output_dir / "footprint_layout", + output_root=output_root, + trimesh=trimesh, + ) + object_scenes = footprint_result["object_scenes"] + + output_axis_transform = _z_up_to_glb_y_up_transform() + object_outputs = [] + for object_id, scene in object_scenes: + object_output = output_dir / f"{object_id}_aligned.glb" + _copy_scene_with_transform(scene, output_axis_transform).export(object_output) + object_outputs.append( + { + "id": object_id, + "aligned_geometry_path": relative_path(str(object_output), output_root), + } + ) + + alignment_matrix = selected_extra_transform @ center_transform @ normal_alignment + scaled_alignment_matrix = metric_scale_transform @ alignment_matrix + final_clutter_aabb_2d_cm = _object_scenes_xy_aabb_manifest( + object_scenes=object_scenes, + trimesh=trimesh, + unit_scale=100.0, + unit="cm", + ) + return { + "status": "ok", + "output_dir": relative_path(str(output_dir), output_root), + "support_normal": support_normal.tolist(), + "clutter_aabb_center_before_centering": clutter_center.tolist(), + "alignment_matrix": scaled_alignment_matrix.tolist(), + "pre_metric_scale_alignment_matrix": alignment_matrix.tolist(), + "global_metric_scale": global_metric_scale, + "final_clutter_2d_aabb_cm": final_clutter_aabb_2d_cm, + "internal_up_axis": [0.0, 0.0, 1.0], + "glb_output_up_axis": [0.0, 1.0, 0.0], + "glb_output_axis_transform": output_axis_transform.tolist(), + "selected_up_down_variant": selected_variant, + "applied_up_down_flip": apply_up_down_flip, + "selected_extra_transform": selected_extra_transform.tolist(), + "object_alignment_matrices": { + object_id: (object_transform @ scaled_alignment_matrix).tolist() + for object_id, object_transform in footprint_result[ + "object_layout_transforms" + ].items() + }, + "footprint_layout": footprint_result["manifest"], + "yaw_sampling": { + "sample_count_per_variant": 360, + "score_type": "center_left_of_hard_count", + "top_view_plane": "XY", + "yaw_axis": "Z", + "left_right_axis": "X", + "front_back_axis": "Y", + "front_direction": "+Y", + "normal": alignment_candidates["normal"]["yaw_metadata"], + "flipped": alignment_candidates["flipped"]["yaw_metadata"], + }, + "up_down_flip_check": up_down_flip_check_result, + "objects": object_outputs, + } + + +def _build_up_down_alignment_candidates( + *, + object_scenes: list[tuple[str, Any]], + support_normal: np.ndarray, + normal_alignment: np.ndarray, + spatial_relations: list[dict[str, Any]], + trimesh: Any, +) -> dict[str, dict[str, Any]]: + flip_transform = _support_normal_flip_transform( + support_normal=support_normal, + normal_alignment=normal_alignment, + ) + directional_relations = _spatial_directional_relations(spatial_relations) + candidates: dict[str, dict[str, Any]] = {} + for variant, pre_yaw_transform in [ + ("normal", np.eye(4, dtype=np.float64)), + ("flipped", flip_transform), + ]: + candidate_object_scenes = [ + (object_id, _copy_scene_with_transform(scene, pre_yaw_transform)) + for object_id, scene in object_scenes + ] + object_bounds = { + object_id: np.asarray( + _scene_to_mesh(scene, trimesh=trimesh).bounds, + dtype=np.float64, + ) + for object_id, scene in candidate_object_scenes + } + yaw_metadata = _best_spatial_yaw( + object_bounds=object_bounds, + relations=directional_relations, + ) + yaw_transform = _z_yaw_transform( + float(yaw_metadata["yaw_degrees"]), + ) + for _, scene in candidate_object_scenes: + scene.apply_transform(yaw_transform) + candidates[variant] = { + "object_scenes": candidate_object_scenes, + "pre_yaw_transform": pre_yaw_transform, + "yaw_transform": yaw_transform, + "extra_transform": yaw_transform @ pre_yaw_transform, + "yaw_metadata": yaw_metadata, + } + return candidates + + +def _best_spatial_yaw( + *, + object_bounds: dict[str, np.ndarray], + relations: list[dict[str, str]], +) -> dict[str, Any]: + if not relations: + return { + "yaw_degrees": 0, + "score": 0, + "raw_gap_sum": 0.0, + "relation_count": 0, + "score_type": "center_left_of_hard_count", + } + + object_centers = { + object_id: _aabb_center(bounds) for object_id, bounds in object_bounds.items() + } + best_yaw = 0 + best_score = -1 + best_raw_gap_sum = float("-inf") + best_relation_scores: list[dict[str, Any]] = [] + for yaw_degrees in range(360): + rotation = _z_yaw_transform(float(yaw_degrees)) + rotated_centers = { + object_id: _transform_point(rotation, center) + for object_id, center in object_centers.items() + } + score, raw_gap_sum, relation_scores = _center_left_of_score( + centers=rotated_centers, + relations=relations, + ) + if score > best_score or ( + score == best_score and raw_gap_sum > best_raw_gap_sum + ): + best_yaw = yaw_degrees + best_score = score + best_raw_gap_sum = raw_gap_sum + best_relation_scores = relation_scores + return { + "yaw_degrees": best_yaw, + "score": best_score, + "raw_gap_sum": best_raw_gap_sum, + "relation_count": len(relations), + "score_type": "center_left_of_hard_count", + "relation_scores": best_relation_scores, + } + + +def _spatial_directional_relations( + spatial_relations: list[dict[str, Any]], +) -> list[dict[str, str]]: + relations: list[dict[str, str]] = [] + seen: set[tuple[str, str, str]] = set() + for relation in spatial_relations: + subject = str(relation.get("subject") or "") + object_id = str(relation.get("object") or "") + relation_name = str(relation.get("relation") or "") + if ( + not subject + or not object_id + or subject == object_id + or relation_name != "left_of" + ): + continue + key = (subject, relation_name, object_id) + if key in seen: + continue + seen.add(key) + relations.append( + { + "subject": subject, + "relation": relation_name, + "object": object_id, + } + ) + return relations + + +def _center_left_of_score( + centers: dict[str, np.ndarray], + relations: list[dict[str, str]], +) -> tuple[int, float, list[dict[str, Any]]]: + score = 0 + raw_gap_sum = 0.0 + relation_scores: list[dict[str, Any]] = [] + for relation in relations: + subject = relation["subject"] + object_id = relation["object"] + if subject not in centers or object_id not in centers: + continue + subject_center = centers[subject] + object_center = centers[object_id] + gap = float(object_center[0] - subject_center[0]) + relation_score = 1 if gap > 0.0 else 0 + score += relation_score + raw_gap_sum += gap + relation_scores.append( + { + "subject": subject, + "relation": "left_of", + "object": object_id, + "gap": gap, + "score": relation_score, + } + ) + return score, raw_gap_sum, relation_scores + + +def _transform_point(transform: np.ndarray, point: np.ndarray) -> np.ndarray: + homogeneous = np.ones(4, dtype=np.float64) + homogeneous[:3] = point + return (transform @ homogeneous)[:3] + + +def _resolve_generated_path(value: Any, output_root: Path) -> Path: + path = Path(str(value or "")).expanduser() + if path.is_absolute(): + return path.resolve() + return (output_root / path).resolve() + + +def _run_aligned_up_down_flip_vlm_check( + *, + llm: Any | None, + original_image_path: Path | None, + normal_object_scenes: list[tuple[str, Any]], + flipped_object_scenes: list[tuple[str, Any]], + output_dir: Path, +) -> dict[str, Any]: + output_dir.mkdir(parents=True, exist_ok=True) + result: dict[str, Any] = { + "status": "skipped", + "applied_up_down_flip": False, + "confidence_threshold": UP_DOWN_FLIP_CHECK_CONFIDENCE_THRESHOLD, + "reason": "", + } + if not normal_object_scenes or not flipped_object_scenes: + result["reason"] = "missing_object_scenes" + return result + + try: + normal_render_path = output_dir / "normal_object_only_front_oblique_view.png" + flipped_render_path = output_dir / "flipped_object_only_front_oblique_view.png" + comparison_image_path = output_dir / "numbered_up_down_candidates.png" + BlenderRenderingManager().render_object_scenes( + RenderObjectScenesRequest( + object_scenes=normal_object_scenes, + output_path=normal_render_path, + ) + ) + BlenderRenderingManager().render_object_scenes( + RenderObjectScenesRequest( + object_scenes=flipped_object_scenes, + output_path=flipped_render_path, + ) + ) + MatplotlibManager(figsize=(12, 6), dpi=180).render_image_comparison( + RenderImageComparisonRequest( + first_image_path=normal_render_path, + second_image_path=flipped_render_path, + output_path=comparison_image_path, + ) + ) + if llm is None: + result["reason"] = "missing_llm" + return result + if original_image_path is None or not original_image_path.is_file(): + result["reason"] = "missing_original_image" + return result + + raw_model_output = call_structured_json_model_step( + llm=llm, + schema=UP_DOWN_FLIP_CHECK_JSON_SCHEMA, + messages=build_up_down_flip_check_messages( + original_image_path=original_image_path, + comparison_image_path=comparison_image_path, + ), + context="Unified scene aligned up-down flip check", + step_name=UNIFIED_SCENE_STEP, + output_root=None, + attempt_count=0, + ) + # Persist VLM raw output alongside the comparison renders + try: + import json as _json + + vlm_result_path = output_dir / "vlm_flip_check_result.json" + vlm_result_path.write_text( + _json.dumps(raw_model_output, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + except Exception: + pass + confidence = float(raw_model_output.get("confidence", 0.0)) + selected_number = int(raw_model_output.get("selected_number", 1)) + if selected_number not in {1, 2}: + selected_number = 1 + model_selected_variant = "flipped" if selected_number == 2 else "normal" + should_apply = ( + model_selected_variant == "flipped" + and confidence >= UP_DOWN_FLIP_CHECK_CONFIDENCE_THRESHOLD + ) + selected_variant = "flipped" if should_apply else "normal" + selected_number = 2 if selected_variant == "flipped" else 1 + result.update( + { + "status": "ok", + "selected_number": selected_number, + "selected_variant": selected_variant, + "applied_up_down_flip": should_apply, + "model_selected_number": raw_model_output.get("selected_number"), + "model_selected_variant": model_selected_variant, + "confidence": confidence, + "reason": str(raw_model_output.get("reason", "")), + } + ) + return result + except Exception: + result.update( + { + "status": "failed", + "reason": traceback.format_exc(), + } + ) + return result diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/manifests.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/manifests.py new file mode 100644 index 000000000..6ae379c3e --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/manifests.py @@ -0,0 +1,212 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.utils.io import ( + relative_path, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.scene_geometry import ( + _compose_json_matrices, + _compose_simready_to_aligned_matrix, + _decompose_transform_matrix, +) +from embodichain.gen_sim.prompt2scene.utils.io import write_json + +__all__ = ["_write_multi_object_layout_manifests"] + + +def _write_multi_object_layout_manifests( + *, + glb_gen_dir: Path, + output_root: Path, + table: dict[str, Any] | None, + objects: list[dict[str, Any]], + alignment: dict[str, Any] | None, +) -> dict[str, str]: + simready_to_aligned_path = glb_gen_dir / "simready_to_aligned_manifest.json" + + write_json( + simready_to_aligned_path, + _simready_to_aligned_manifest( + table=table, + items=objects, + alignment=alignment, + output_root=output_root, + ), + ) + return { + "simready_to_aligned_manifest_path": relative_path( + str(simready_to_aligned_path), + output_root, + ), + } + + +def _simready_to_aligned_manifest( + *, + table: dict[str, Any] | None, + items: list[dict[str, Any]], + alignment: dict[str, Any] | None, + output_root: Path, +) -> dict[str, Any]: + alignment = alignment or {} + alignment_matrix = alignment.get("alignment_matrix", []) + glb_output_axis_transform = alignment.get("glb_output_axis_transform", []) + object_alignment_matrices = alignment.get("object_alignment_matrices", {}) + aligned_by_id = _aligned_outputs_by_id(alignment) + return { + "note": ( + "Aligned GLBs are generated from raw_downloads plus SAM3D layout " + "matrices in memory; simready paths are recorded here as the " + "simulation-ready counterpart for each raw GLB." + ), + "alignment_status": alignment.get("status", ""), + "alignment_reason": alignment.get("reason", ""), + "selected_up_down_variant": alignment.get("selected_up_down_variant", ""), + "applied_up_down_flip": alignment.get("applied_up_down_flip", False), + "alignment_matrix": alignment_matrix, + "global_metric_scale": alignment.get("global_metric_scale"), + "final_clutter_2d_aabb_cm": alignment.get("final_clutter_2d_aabb_cm"), + "glb_output_axis_transform": glb_output_axis_transform, + "table": ( + _simready_manifest_table_item(table, output_root=output_root) + if table is not None + else None + ), + "items": [ + _simready_to_aligned_manifest_item( + item, + aligned_by_id=aligned_by_id, + alignment_matrix=alignment_matrix, + object_alignment_matrices=object_alignment_matrices, + glb_output_axis_transform=glb_output_axis_transform, + output_root=output_root, + ) + for item in items + ], + } + + +def _aligned_outputs_by_id(alignment: dict[str, Any]) -> dict[str, str]: + outputs: dict[str, str] = {} + for item in alignment.get("objects", []) or []: + if isinstance(item, dict) and item.get("id"): + outputs[str(item["id"])] = str(item.get("aligned_geometry_path", "")) + return outputs + + +def _simready_manifest_table_item( + item: dict[str, Any], + *, + output_root: Path, +) -> dict[str, Any]: + return { + "id": item.get("id", ""), + "name": item.get("name", ""), + "kind": item.get("kind", "table"), + "status": item.get("status", ""), + "simready_geometry_path": ( + relative_path( + str( + _resolve_generated_path( + item.get("simready_geometry_path"), output_root + ) + ), + output_root, + ) + if item.get("simready_geometry_path") + else "" + ), + "support_reference_geometry_path": ( + relative_path( + str( + _resolve_generated_path( + item.get("support_reference_geometry_path"), + output_root, + ) + ), + output_root, + ) + if item.get("support_reference_geometry_path") + else "" + ), + "table_asset_source": item.get("table_asset_source", ""), + "support_normal_source": item.get("support_normal_source", ""), + "is_complete_visible_table": item.get("is_complete_visible_table", False), + "complete_table_description": item.get("complete_table_description", ""), + } + + +def _simready_to_aligned_manifest_item( + item: dict[str, Any], + *, + aligned_by_id: dict[str, str], + alignment_matrix: Any, + object_alignment_matrices: Any, + glb_output_axis_transform: Any, + output_root: Path, +) -> dict[str, Any]: + item_id = str(item.get("id", "")) + sam3d_transform = item.get("transform_matrix", []) + item_alignment_matrix = alignment_matrix + if isinstance(object_alignment_matrices, dict): + item_alignment_matrix = object_alignment_matrices.get( + item_id, + alignment_matrix, + ) + raw_to_aligned_matrix = _compose_json_matrices( + glb_output_axis_transform, + item_alignment_matrix, + sam3d_transform, + ) + simready_to_aligned_matrix = _compose_simready_to_aligned_matrix( + raw_to_aligned_matrix=raw_to_aligned_matrix, + raw_to_simready_matrix=item.get("raw_to_simready_glb_matrix", []), + ) + decomposed = _decompose_transform_matrix(simready_to_aligned_matrix) + return { + "id": item_id, + "name": item.get("name", ""), + "kind": item.get("kind", ""), + "simready_geometry_path": item.get("simready_geometry_path", ""), + "aligned_geometry_path": aligned_by_id.get(item_id, ""), + "metric_scale": _trim_metric_scale(item.get("metric_scale")), + "simready_to_aligned_matrix": simready_to_aligned_matrix, + "translation": decomposed["translation"], + "rotation_matrix": decomposed["rotation_matrix"], + "scale": decomposed["scale"], + } + + +def _trim_metric_scale(value: Any) -> dict[str, Any] | None: + if not isinstance(value, dict): + return None + metric_scale = dict(value) + for key in ["result_path", "raw_model_output_path"]: + metric_scale.pop(key, None) + return metric_scale + + +def _resolve_generated_path(value: Any, output_root: Path) -> Path: + path = Path(str(value or "")).expanduser() + if path.is_absolute(): + return path.resolve() + return (output_root / path).resolve() diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/prompts.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/prompts.py new file mode 100644 index 000000000..85b41388b --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/prompts.py @@ -0,0 +1,106 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.prompts import render_prompt +from embodichain.gen_sim.prompt2scene.utils.io import image_to_data_url + +__all__ = [ + "build_image_metric_scale_messages", + "build_up_down_flip_check_messages", +] + +UNIFIED_SCENE_GEN_PROMPT_NAME = "unified_scene_gen.yaml" + + +def build_image_metric_scale_messages( + *, + bbox_name_image_path: Path, + objects_json: list[dict[str, Any]], +) -> list[dict[str, Any]]: + return [ + { + "role": "system", + "content": render_prompt( + UNIFIED_SCENE_GEN_PROMPT_NAME, + prompt_key="image_metric_scale_system", + ), + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": render_prompt( + UNIFIED_SCENE_GEN_PROMPT_NAME, + { + "objects_json": json.dumps( + objects_json, + ensure_ascii=False, + indent=2, + ), + }, + prompt_key="image_metric_scale_user", + ), + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(bbox_name_image_path)}, + }, + ], + }, + ] + + +def build_up_down_flip_check_messages( + *, + original_image_path: Path, + comparison_image_path: Path, +) -> list[dict[str, Any]]: + return [ + { + "role": "system", + "content": render_prompt( + UNIFIED_SCENE_GEN_PROMPT_NAME, + prompt_key="up_down_flip_check_system", + ), + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": render_prompt( + UNIFIED_SCENE_GEN_PROMPT_NAME, + prompt_key="up_down_flip_check_user", + ), + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(original_image_path)}, + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(comparison_image_path)}, + }, + ], + }, + ] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/schemas.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/schemas.py new file mode 100644 index 000000000..b22fcebba --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/schemas.py @@ -0,0 +1,71 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +__all__ = [ + "IMAGE_METRIC_SCALE_JSON_SCHEMA", + "UP_DOWN_FLIP_CHECK_JSON_SCHEMA", +] + +UP_DOWN_FLIP_CHECK_JSON_SCHEMA: dict[str, Any] = { + "title": "AlignedUpDownFlipCheckOutput", + "type": "object", + "additionalProperties": False, + "properties": { + "selected_number": {"type": "integer", "enum": [1, 2]}, + "confidence": {"type": "number", "minimum": 0.0, "maximum": 1.0}, + "reason": {"type": "string"}, + }, + "required": ["selected_number", "confidence", "reason"], +} + +IMAGE_METRIC_SCALE_JSON_SCHEMA: dict[str, Any] = { + "title": "ImageMetricScaleEstimate", + "type": "object", + "additionalProperties": False, + "properties": { + "object_scales": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "object_id": {"type": "string"}, + "bbox_dims_cm": { + "type": "array", + "minItems": 3, + "maxItems": 3, + "items": { + "type": "number", + "minimum": 1.0e-6, + }, + }, + "confidence": { + "type": "number", + "minimum": 0.0, + "maximum": 1.0, + }, + "reason": {"type": "string"}, + }, + "required": ["object_id", "bbox_dims_cm", "confidence", "reason"], + }, + }, + }, + "required": ["object_scales"], +} diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_segmentation_manager/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_segmentation_manager/__init__.py new file mode 100644 index 000000000..fbbf31487 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_segmentation_manager/__init__.py @@ -0,0 +1,33 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_segmentation_manager.manager import ( + ImageSegmentationManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_segmentation_manager.schemas import ( + AssetImageToRgbaRequest, + ImageSegmentationRequest, + ImageSegmentationResult, +) + +__all__ = [ + "AssetImageToRgbaRequest", + "ImageSegmentationManager", + "ImageSegmentationRequest", + "ImageSegmentationResult", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_segmentation_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_segmentation_manager/manager.py new file mode 100644 index 000000000..052b8d7db --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_segmentation_manager/manager.py @@ -0,0 +1,90 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client import ( + ImageSegmentationClient, + ImageSegmentationError, + ImageSegmentationServerRequest, + apply_mask_to_alpha, + decode_rle_mask, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_segmentation_manager.schemas import ( + AssetImageToRgbaRequest, + ImageSegmentationRequest, + ImageSegmentationResult, +) + + +class ImageSegmentationManager: + """Image segmentation domain operations.""" + + def __init__(self, *, client: ImageSegmentationClient | None = None) -> None: + self.client = client or ImageSegmentationClient() + + def segment_image( + self, + request: ImageSegmentationRequest, + ) -> ImageSegmentationResult: + image_path = request.image_path.expanduser().resolve() + _validate_segment_request(image_path=image_path, prompt=request.prompt) + + response = self.client.segment( + ImageSegmentationServerRequest( + prompt=request.prompt.strip(), + image_path=image_path, + ), + ) + if isinstance(response, ImageSegmentationError): + raise RuntimeError(response.error_message) + + return ImageSegmentationResult(candidates=list(response.result.candidates)) + + def convert_asset_image_to_rgba( + self, + request: AssetImageToRgbaRequest, + ) -> Path: + segmentation_result = self.segment_image( + ImageSegmentationRequest( + image_path=request.image_path, + prompt=request.prompt, + ) + ) + if not segmentation_result.candidates: + raise ValueError("Image segmentation returned no candidates.") + + candidate = segmentation_result.candidates[0] + if candidate.mask_rle is None: + raise ValueError(f"Candidate {candidate.candidate_id} has no mask_rle.") + + output_path = request.output_path.expanduser().resolve() + output_path.parent.mkdir(parents=True, exist_ok=True) + mask = decode_rle_mask(candidate.mask_rle) + rgba = apply_mask_to_alpha(request.image_path, mask) + rgba.save(output_path) + if not output_path.is_file(): + raise FileNotFoundError(f"RGBA image was not written: {output_path}") + return output_path + + +def _validate_segment_request(*, image_path: Path, prompt: str) -> None: + if not image_path.is_file(): + raise FileNotFoundError(f"Image segmentation input not found: {image_path}") + if not prompt.strip(): + raise ValueError("Image segmentation prompt must be non-empty.") diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_segmentation_manager/schemas.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_segmentation_manager/schemas.py new file mode 100644 index 000000000..d59b7e7a9 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_segmentation_manager/schemas.py @@ -0,0 +1,48 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client import ( + ImageSegmentationCandidate, +) + + +@dataclass(frozen=True) +class AssetImageToRgbaRequest: + """Request for converting an asset image to an RGBA cutout.""" + + image_path: Path + prompt: str + output_path: Path + + +@dataclass(frozen=True) +class ImageSegmentationRequest: + """Request for segmenting one image with one text prompt.""" + + image_path: Path + prompt: str + + +@dataclass(frozen=True) +class ImageSegmentationResult: + """Segmentation candidates.""" + + candidates: list[ImageSegmentationCandidate] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/__init__.py new file mode 100644 index 000000000..21cf6c253 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/__init__.py @@ -0,0 +1,43 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.matplotlib_manager.manager import ( + MatplotlibManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.matplotlib_manager.schemas import ( + RenderFootprintLayoutRequest, + RenderFootprintLayoutResult, + RenderImageComparisonRequest, + RenderImageComparisonResult, + RenderSupportRegionRequest, + RenderSupportRegionResult, + RenderXYComparisonRequest, + RenderXYComparisonResult, +) + +__all__ = [ + "MatplotlibManager", + "RenderFootprintLayoutRequest", + "RenderFootprintLayoutResult", + "RenderImageComparisonRequest", + "RenderImageComparisonResult", + "RenderSupportRegionRequest", + "RenderSupportRegionResult", + "RenderXYComparisonRequest", + "RenderXYComparisonResult", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/manager.py new file mode 100644 index 000000000..1feb13c3f --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/manager.py @@ -0,0 +1,401 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Matplotlib manager for mesh visualization.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import numpy as np +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt +from matplotlib.collections import PolyCollection +from matplotlib.patches import Rectangle +from mpl_toolkits.mplot3d.art3d import Poly3DCollection + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.matplotlib_manager.schemas import ( + RenderFootprintLayoutRequest, + RenderFootprintLayoutResult, + RenderImageComparisonRequest, + RenderImageComparisonResult, + RenderSupportRegionRequest, + RenderSupportRegionResult, + RenderXYComparisonRequest, + RenderXYComparisonResult, +) + +__all__ = ["MatplotlibManager"] + + +class MatplotlibManager: + """Manager for mesh visualization via matplotlib. + + Wraps matplotlib rendering with typed request/response methods, + following the same pattern as service clients. + """ + + def __init__( + self, + *, + figsize: tuple[float, float] = (8, 8), + dpi: int = 180, + ) -> None: + """Initialize the matplotlib manager. + + Args: + figsize: Default figure size for rendered images. + dpi: Output image resolution. + """ + self._figsize = figsize + self._dpi = dpi + + def render_footprint_layout( + self, + request: RenderFootprintLayoutRequest, + ) -> RenderFootprintLayoutResult: + """Render labeled XY footprints with full-length coordinate axes.""" + output_path = request.output_path.expanduser().resolve() + output_path.parent.mkdir(parents=True, exist_ok=True) + if not request.object_ids: + return RenderFootprintLayoutResult(output_path=output_path) + + centers = { + object_id: np.asarray(request.centers[object_id], dtype=float) + for object_id in request.object_ids + } + sizes = { + object_id: np.asarray(request.xy_sizes[object_id], dtype=float) + for object_id in request.object_ids + } + footprint_mins = np.vstack( + [ + centers[object_id] - 0.5 * sizes[object_id] + for object_id in request.object_ids + ] + ) + footprint_maxs = np.vstack( + [ + centers[object_id] + 0.5 * sizes[object_id] + for object_id in request.object_ids + ] + ) + data_min = footprint_mins.min(axis=0) + data_max = footprint_maxs.max(axis=0) + span = np.maximum(data_max - data_min, 1.0e-6) + padding = max(float(span.max()) * 0.12, 1.0e-3) + x_limits = (float(data_min[0] - padding), float(data_max[0] + padding)) + y_limits = (float(data_min[1] - padding), float(data_max[1] + padding)) + + fig, ax = plt.subplots(figsize=self._figsize) + for object_id in request.object_ids: + center = centers[object_id] + size = sizes[object_id] + ax.add_patch( + Rectangle( + (center[0] - 0.5 * size[0], center[1] - 0.5 * size[1]), + size[0], + size[1], + facecolor=(0.35, 0.60, 0.95, 0.30), + edgecolor=(0.08, 0.22, 0.60, 1.0), + linewidth=1.5, + ) + ) + label = object_id.replace("interact_", "").removesuffix("_0") + ax.text( + center[0], + center[1], + label, + ha="center", + va="center", + fontsize=9, + color="black", + ) + + self._draw_full_xy_axes(ax, x_limits=x_limits, y_limits=y_limits) + ax.set_xlim(*x_limits) + ax.set_ylim(*y_limits) + ax.set_aspect("equal", adjustable="box") + ax.set_title(request.title) + ax.grid(True, linestyle=":", linewidth=0.6, alpha=0.30) + fig.tight_layout() + fig.savefig(output_path, dpi=self._dpi) + plt.close(fig) + return RenderFootprintLayoutResult(output_path=output_path) + + def render_image_comparison( + self, + request: RenderImageComparisonRequest, + ) -> RenderImageComparisonResult: + """Render two images side by side with numbered labels.""" + output_path = request.output_path.expanduser().resolve() + output_path.parent.mkdir(parents=True, exist_ok=True) + first_image = plt.imread(request.first_image_path.expanduser().resolve()) + second_image = plt.imread(request.second_image_path.expanduser().resolve()) + + fig, axes = plt.subplots(1, 2, figsize=(12, 6)) + for ax, image, label in ( + (axes[0], first_image, request.first_label), + (axes[1], second_image, request.second_label), + ): + ax.imshow(image) + ax.set_title(label, fontsize=16, loc="left") + ax.axis("off") + fig.tight_layout() + fig.savefig(output_path, dpi=self._dpi, facecolor="white") + plt.close(fig) + return RenderImageComparisonResult(output_path=output_path) + + @staticmethod + def _draw_full_xy_axes( + ax: Any, + *, + x_limits: tuple[float, float], + y_limits: tuple[float, float], + ) -> None: + """Draw axes across the full viewport, centered on the data bounds.""" + axis_color = "#303030" + x_center = 0.5 * (x_limits[0] + x_limits[1]) + y_center = 0.5 * (y_limits[0] + y_limits[1]) + # Horizontal axis (X) — spans full width, positioned at vertical centre. + ax.annotate( + "", + xy=(x_limits[1], y_center), + xytext=(x_limits[0], y_center), + arrowprops={"arrowstyle": "->", "color": axis_color, "lw": 1.8}, + zorder=8, + ) + # Vertical axis (Y) — spans full height, positioned at horizontal centre. + ax.annotate( + "", + xy=(x_center, y_limits[1]), + xytext=(x_center, y_limits[0]), + arrowprops={"arrowstyle": "->", "color": axis_color, "lw": 1.8}, + zorder=8, + ) + x_span = x_limits[1] - x_limits[0] + y_span = y_limits[1] - y_limits[0] + ax.text( + x_limits[1] - 0.03 * x_span, + y_center + 0.02 * y_span, + "+X", + ha="right", + va="bottom", + color=axis_color, + fontsize=11, + ) + ax.text( + x_center + 0.02 * x_span, + y_limits[1] - 0.03 * y_span, + "+Y", + ha="left", + va="top", + color=axis_color, + fontsize=11, + ) + # Mark the origin at the centre. + ax.plot(x_center, y_center, "o", color=axis_color, markersize=6, zorder=9) + ax.text( + x_center + 0.015 * x_span, + y_center + 0.015 * y_span, + "Origin", + fontsize=8, + color=axis_color, + ha="left", + va="bottom", + zorder=9, + ) + + def render_selected_support_region( + self, request: RenderSupportRegionRequest + ) -> RenderSupportRegionResult: + """Render a mesh with the selected support region highlighted.""" + output_path = request.output_path.expanduser().resolve() + output_path.parent.mkdir(parents=True, exist_ok=True) + + vertices = np.asarray(request.mesh.vertices, dtype=float) + faces = np.asarray(request.mesh.faces, dtype=int) + selected_faces = faces[np.asarray(request.face_indices, dtype=int)] + + fig = plt.figure(figsize=self._figsize) + ax = fig.add_subplot(111, projection="3d") + ax.add_collection3d( + Poly3DCollection( + vertices[faces], + facecolors=(0.65, 0.68, 0.72, 0.16), + edgecolors=(0.35, 0.37, 0.40, 0.08), + linewidths=0.15, + ) + ) + ax.add_collection3d( + Poly3DCollection( + vertices[selected_faces], + facecolors=(1.0, 0.18, 0.05, 0.88), + edgecolors=(0.55, 0.02, 0.0, 1.0), + linewidths=0.8, + ) + ) + self._set_equal_axes(ax, vertices) + ax.view_init(elev=25.0, azim=-45.0) + ax.set_xlabel("X") + ax.set_ylabel("Y") + ax.set_zlabel("Z") + ax.set_title("Selected Support Region") + fig.tight_layout() + fig.savefig(output_path, dpi=self._dpi) + plt.close(fig) + return RenderSupportRegionResult(output_path=output_path) + + def render_xy_alignment_comparison( + self, request: RenderXYComparisonRequest + ) -> RenderXYComparisonResult: + """Render before/after XY projections for PCA yaw alignment.""" + output_path = request.output_path.expanduser().resolve() + output_path.parent.mkdir(parents=True, exist_ok=True) + + before_polygons, before_xy = self._xy_polygons_and_vertices(request.before_mesh) + after_polygons, after_xy = self._xy_polygons_and_vertices(request.after_mesh) + center, view_half = self._xy_view_bounds(before_xy, after_xy) + + fig, axes = plt.subplots(1, 2, figsize=self._figsize) + self._draw_xy_projection( + axes[0], + before_polygons, + before_xy, + "Before PCA yaw", + center, + view_half, + ) + self._draw_xy_projection( + axes[1], + after_polygons, + after_xy, + f"After PCA yaw ({request.angle_degrees:.2f} deg)", + center, + view_half, + ) + fig.tight_layout() + fig.savefig(output_path, dpi=self._dpi) + plt.close(fig) + return RenderXYComparisonResult(output_path=output_path) + + @staticmethod + def _xy_polygons_and_vertices(mesh: Any) -> tuple[Any, Any]: + vertices = np.asarray(mesh.vertices, dtype=float) + faces = np.asarray(mesh.faces, dtype=int) + return vertices[faces][:, :, :2], vertices[:, :2] + + @staticmethod + def _xy_view_bounds(before_xy: Any, after_xy: Any) -> tuple[Any, float]: + values = np.concatenate([before_xy, after_xy], axis=0) + bounds_min = values.min(axis=0) + bounds_max = values.max(axis=0) + center = 0.5 * (bounds_min + bounds_max) + span = np.maximum(bounds_max - bounds_min, 1e-3) + view_half = max(float(span.max()) * 0.65, 0.5) + return center, view_half + + def _draw_xy_projection( + self, + ax: Any, + polygons_xy: Any, + vertices_xy: Any, + title: str, + center: Any, + view_half: float, + ) -> None: + ax.add_collection( + PolyCollection( + polygons_xy, + facecolors=(0.24, 0.50, 0.90, 0.28), + edgecolors=(0.05, 0.16, 0.35, 0.20), + linewidths=0.20, + ) + ) + self._draw_xy_aabb(ax, vertices_xy) + self._add_xy_axes(ax, view_half) + ax.set_xlim(center[0] - view_half, center[0] + view_half) + ax.set_ylim(center[1] - view_half, center[1] + view_half) + ax.set_aspect("equal", adjustable="box") + ax.set_xlabel("X") + ax.set_ylabel("Y") + ax.set_title(title) + ax.grid(True, which="major", linestyle="-", linewidth=0.7, alpha=0.35) + ax.minorticks_on() + ax.grid(True, which="minor", linestyle=":", linewidth=0.45, alpha=0.25) + + @staticmethod + def _draw_xy_aabb(ax: Any, vertices_xy: Any) -> None: + bounds_min = vertices_xy.min(axis=0) + bounds_max = vertices_xy.max(axis=0) + width, height = bounds_max - bounds_min + ax.add_patch( + Rectangle( + (bounds_min[0], bounds_min[1]), + width, + height, + fill=False, + edgecolor="#d62828", + linewidth=1.6, + linestyle="-", + alpha=0.95, + ) + ) + + @staticmethod + def _add_xy_axes(ax: Any, view_half: float) -> None: + arrow_len = max(view_half * 0.35, 0.2) + ax.scatter([0.0], [0.0], color="black", s=22, zorder=8) + ax.text(0.0, 0.0, " Origin", fontsize=9, ha="left", va="bottom") + ax.arrow( + 0.0, + 0.0, + arrow_len, + 0.0, + width=arrow_len * 0.015, + head_width=arrow_len * 0.06, + head_length=arrow_len * 0.08, + color="#d62828", + length_includes_head=True, + zorder=9, + ) + ax.text(arrow_len * 1.08, 0.0, "+X", color="#d62828", fontsize=11) + ax.arrow( + 0.0, + 0.0, + 0.0, + arrow_len, + width=arrow_len * 0.015, + head_width=arrow_len * 0.06, + head_length=arrow_len * 0.08, + color="#2a9d8f", + length_includes_head=True, + zorder=9, + ) + ax.text(0.0, arrow_len * 1.08, "+Y", color="#2a9d8f", fontsize=11) + + @staticmethod + def _set_equal_axes(ax: Any, vertices: Any) -> None: + mins = np.min(vertices, axis=0) + maxs = np.max(vertices, axis=0) + center = (mins + maxs) * 0.5 + radius = max(float(np.max(maxs - mins)) * 0.5, 1e-6) + ax.set_xlim(center[0] - radius, center[0] + radius) + ax.set_ylim(center[1] - radius, center[1] + radius) + ax.set_zlim(center[2] - radius, center[2] + radius) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/schemas.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/schemas.py new file mode 100644 index 000000000..764383f38 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/schemas.py @@ -0,0 +1,101 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +__all__ = [ + "RenderFootprintLayoutRequest", + "RenderFootprintLayoutResult", + "RenderImageComparisonRequest", + "RenderImageComparisonResult", + "RenderSupportRegionRequest", + "RenderSupportRegionResult", + "RenderXYComparisonRequest", + "RenderXYComparisonResult", +] + + +@dataclass(frozen=True) +class RenderFootprintLayoutRequest: + """Request to render labeled top-down object footprints.""" + + object_ids: list[str] + centers: dict[str, Any] + xy_sizes: dict[str, Any] + output_path: Path + title: str = "" + + +@dataclass(frozen=True) +class RenderFootprintLayoutResult: + """Result of rendering a footprint layout.""" + + output_path: Path + + +@dataclass(frozen=True) +class RenderImageComparisonRequest: + """Request to render two labeled images side by side.""" + + first_image_path: Path + second_image_path: Path + output_path: Path + first_label: str = "1: normal" + second_label: str = "2: flipped" + + +@dataclass(frozen=True) +class RenderImageComparisonResult: + """Result of rendering an image comparison.""" + + output_path: Path + + +@dataclass(frozen=True) +class RenderSupportRegionRequest: + """Request to render a mesh with the selected support region highlighted.""" + + mesh: Any + face_indices: list[int] + output_path: Path + + +@dataclass(frozen=True) +class RenderSupportRegionResult: + """Result of rendering the support region.""" + + output_path: Path + + +@dataclass(frozen=True) +class RenderXYComparisonRequest: + """Request to render before/after XY projections for PCA yaw alignment.""" + + before_mesh: Any + after_mesh: Any + angle_degrees: float + output_path: Path + + +@dataclass(frozen=True) +class RenderXYComparisonResult: + """Result of rendering the XY alignment comparison.""" + + output_path: Path diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/__init__.py new file mode 100644 index 000000000..8eca3510d --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/__init__.py @@ -0,0 +1,37 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.metric_scale_manager.manager import ( + METRIC_SCALE_ENABLED, + MetricScaleManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.metric_scale_manager.schemas import ( + EstimateMetricScalesRequest, + EstimateMetricScalesResult, + GlobalMetricScaleRequest, + MetricScaleObjectInput, +) + +__all__ = [ + "METRIC_SCALE_ENABLED", + "EstimateMetricScalesRequest", + "EstimateMetricScalesResult", + "GlobalMetricScaleRequest", + "MetricScaleManager", + "MetricScaleObjectInput", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/manager.py new file mode 100644 index 000000000..ce1d47e9a --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/manager.py @@ -0,0 +1,431 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +import numpy as np + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager import ( + GeometryManager, + LoadMeshRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.metric_scale_manager.schemas import ( + EstimateMetricScalesRequest, + EstimateMetricScalesResult, + GlobalMetricScaleRequest, + MetricScaleObjectInput, +) +from embodichain.gen_sim.prompt2scene.utils.io import write_json +from embodichain.gen_sim.prompt2scene.workflows.llm_output import ( + call_structured_json_model_step, +) + +__all__ = ["METRIC_SCALE_ENABLED", "MetricScaleManager"] + +METRIC_SCALE_ENABLED = True + + +class MetricScaleManager: + """Manager for metric scale estimation and scale aggregation.""" + + @staticmethod + def estimate_metric_scales( + request: EstimateMetricScalesRequest, + ) -> EstimateMetricScalesResult: + """Call an LLM and convert bbox-size predictions into scale factors.""" + object_payload = MetricScaleManager.build_object_payload(request.objects) + raw_model_output_path = ( + request.raw_output_path.expanduser().resolve() + if request.raw_output_path is not None + else None + ) + raw_model_output = call_structured_json_model_step( + llm=request.llm, + schema=request.schema, + messages=request.messages, + context=request.context, + step_name=request.step_name, + output_root=None, + attempt_count=0, + raw_output_writer=( + (lambda payload: write_json(raw_model_output_path, payload)) + if raw_model_output_path is not None + else None + ), + ) + object_scales = MetricScaleManager.apply_model_output( + object_payload=object_payload, + raw_model_output=raw_model_output, + method=request.method, + ) + return EstimateMetricScalesResult( + status="ok", + object_scales=object_scales, + object_payload=object_payload, + raw_model_output=raw_model_output, + ) + + @staticmethod + def build_object_payload( + objects: list[MetricScaleObjectInput], + ) -> list[dict[str, Any]]: + """Build object payload with normalized mesh bbox measurements.""" + geom = GeometryManager() + payload: list[dict[str, Any]] = [] + for obj in objects: + mesh = geom.load_mesh(LoadMeshRequest(mesh_path=obj.mesh_path)).mesh + normalized_bbox_size_m = GeometryManager.mesh_aabb_size(mesh) + payload.append( + { + "object_id": obj.object_id, + "object_name": obj.object_name, + "object_description": obj.object_description, + "normalized_bbox_size_m": normalized_bbox_size_m.tolist(), + "normalized_bbox_ratio": GeometryManager.bbox_ratio( + normalized_bbox_size_m + ).tolist(), + } + ) + return payload + + @staticmethod + def object_prompt_payload( + objects: list[MetricScaleObjectInput], + ) -> list[dict[str, str]]: + """Return the lightweight object payload intended for LLM prompts.""" + return [ + { + "object_id": obj.object_id, + "object_name": obj.object_name, + "object_description": obj.object_description, + } + for obj in objects + ] + + @staticmethod + def apply_model_output( + *, + object_payload: list[dict[str, Any]], + raw_model_output: dict[str, Any], + method: str, + ) -> list[dict[str, Any]]: + """Convert model bbox predictions into per-object metric-scale records.""" + model_by_id = { + str(item.get("object_id", "")): item + for item in raw_model_output.get("object_scales", []) + if isinstance(item, dict) + } + estimates: list[dict[str, Any]] = [] + for payload in object_payload: + object_id = str(payload.get("object_id", "")) + model_item = model_by_id.get(object_id) + if model_item is None: + estimates.append( + MetricScaleManager.failure( + object_id=object_id, + reason="missing_object_scale_from_model", + method=method, + ) + ) + continue + estimates.append( + MetricScaleManager.select_candidate( + object_id=object_id, + object_name=str(payload.get("object_name", "")), + object_description=str(payload.get("object_description", "")), + bbox_dims_cm=model_item.get("bbox_dims_cm", []), + confidence=float(model_item.get("confidence", 0.0)), + reason=str(model_item.get("reason", "")), + normalized_bbox_size_m=np.asarray( + payload["normalized_bbox_size_m"], + dtype=np.float64, + ), + method=method, + ) + ) + return estimates + + @staticmethod + def apply_to_objects( + *, + objects: list[dict[str, Any]], + object_scales: list[dict[str, Any]], + ) -> None: + """Attach metric-scale records to object dictionaries by object id.""" + scale_by_id = {str(item.get("object_id", "")): item for item in object_scales} + for obj in objects: + object_id = str(obj.get("id", "")) + if object_id in scale_by_id: + obj["metric_scale"] = scale_by_id[object_id] + + @staticmethod + def select_candidate( + *, + object_id: str, + object_name: str, + object_description: str, + bbox_dims_cm: Any, + confidence: float, + reason: str, + normalized_bbox_size_m: np.ndarray, + method: str, + ) -> dict[str, Any]: + """Select a scale factor from predicted real-world bbox dimensions.""" + try: + selected = MetricScaleManager.compute_from_bbox_dims( + bbox_dims_cm=bbox_dims_cm, + confidence=confidence, + reason=reason, + normalized_bbox_size_m=normalized_bbox_size_m, + ) + except (TypeError, ValueError): + return MetricScaleManager.failure( + object_id=object_id, + reason="invalid_bbox_dims_cm", + method=method, + ) + normalized_bbox_size_cm = ( + np.asarray(normalized_bbox_size_m, dtype=np.float64) * 100.0 + ) + return { + "status": "ok", + "method": method, + "object_id": object_id, + "object_name": object_name, + "object_description": object_description, + "normalized_bbox_size_m": normalized_bbox_size_m.tolist(), + "normalized_bbox_size_cm": normalized_bbox_size_cm.tolist(), + "normalized_bbox_ratio": GeometryManager.bbox_ratio( + normalized_bbox_size_m + ).tolist(), + "bbox_dims_cm": selected["bbox_dims_cm"], + "axis_match": selected["axis_match"], + "scale_factor": selected["scale_factor"], + "confidence": selected["confidence"], + "reason": selected["reason"], + "unit_note": "scale_factor is not baked into this GLB.", + } + + @staticmethod + def compute_from_bbox_dims( + *, + bbox_dims_cm: Any, + confidence: float, + reason: str, + normalized_bbox_size_m: np.ndarray, + ) -> dict[str, Any]: + """Compute one scale candidate from model-predicted bbox dimensions.""" + dims_cm = np.asarray( + [float(value) for value in bbox_dims_cm], + dtype=np.float64, + ) + if dims_cm.shape != (3,) or np.any(dims_cm <= 0.0): + raise ValueError("bbox_dims_cm must contain three positive values.") + normalized_bbox_size_cm = ( + np.asarray(normalized_bbox_size_m, dtype=np.float64) * 100.0 + ) + axis_match = GeometryManager.best_axis_bbox_scale_match( + source_size_cm=normalized_bbox_size_cm, + target_size_cm=dims_cm, + ) + return { + "bbox_dims_cm": dims_cm.tolist(), + "axis_match": axis_match, + "scale_factor": float(axis_match["scale_factor"]), + "confidence": confidence, + "reason": reason, + } + + @staticmethod + def failure( + *, + object_id: str, + reason: str, + method: str, + ) -> dict[str, Any]: + """Build a failed per-object metric-scale record.""" + return { + "status": "failed", + "method": method, + "object_id": object_id, + "scale_factor": 1.0, + "reason": reason, + } + + @staticmethod + def set_for_all_objects( + *, + objects: list[dict[str, Any]], + status: str, + reason: str, + method: str, + ) -> None: + """Attach the same fallback metric-scale status to all objects.""" + for obj in objects: + obj["metric_scale"] = { + "status": status, + "method": method, + "object_id": str(obj.get("id", "")), + "scale_factor": 1.0, + "reason": reason, + } + + @staticmethod + def compute_global_from_object_scenes( + request: GlobalMetricScaleRequest, + ) -> dict[str, Any]: + """Aggregate object metric scales into one global scale for a scene layout.""" + if not METRIC_SCALE_ENABLED: + return { + "status": "disabled", + "method": "metric_scale_disabled", + "scale_factor": 1.0, + "object_count": len(request.objects), + "used_count": 0, + "skipped_count": len(request.objects), + "used": [], + "skipped": [ + {"id": str(item.get("id", "")), "reason": "metric_scale_disabled"} + for item in request.objects + ], + "unit_note": ( + "Metric scale is disabled; aligned GLBs keep simready " + "normalized size." + ), + } + + used: list[dict[str, Any]] = [] + skipped: list[dict[str, Any]] = [] + object_by_id = {str(item.get("id", "")): item for item in request.objects} + for object_id, scene in request.object_scenes: + item = object_by_id.get(object_id) + if item is None: + skipped.append({"id": object_id, "reason": "missing_object_record"}) + continue + metric_scale = item.get("metric_scale") + if not isinstance(metric_scale, dict): + skipped.append({"id": object_id, "reason": "missing_metric_scale"}) + continue + if metric_scale.get("status") != "ok": + skipped.append( + { + "id": object_id, + "reason": str(metric_scale.get("status") or "not_ok"), + } + ) + continue + + scale_factor_simready = float(metric_scale.get("scale_factor", 1.0)) + if not np.isfinite(scale_factor_simready) or scale_factor_simready <= 0.0: + skipped.append( + {"id": object_id, "reason": "invalid_simready_scale_factor"} + ) + continue + try: + simready_size_m = np.asarray( + [float(v) for v in metric_scale.get("normalized_bbox_size_m", [])], + dtype=np.float64, + ) + except (TypeError, ValueError): + skipped.append( + {"id": object_id, "reason": "invalid_normalized_bbox_size_m"} + ) + continue + if simready_size_m.shape != (3,) or np.any(simready_size_m <= 0.0): + skipped.append( + {"id": object_id, "reason": "invalid_normalized_bbox_size_m"} + ) + continue + + current_bounds = np.asarray(GeometryManager.scene_to_mesh(scene).bounds) + current_size_m = current_bounds[1] - current_bounds[0] + if current_size_m.shape != (3,) or np.any(current_size_m <= 0.0): + skipped.append({"id": object_id, "reason": "invalid_current_scene_aabb"}) + continue + + geo_ratio = np.sort(current_size_m) / np.sort(simready_size_m) + geo_scale = float(np.median(geo_ratio)) + if not np.isfinite(geo_scale) or geo_scale <= 0.0: + skipped.append({"id": object_id, "reason": "non_positive_geo_scale"}) + continue + + effective_scale = scale_factor_simready / geo_scale + if not np.isfinite(effective_scale) or effective_scale <= 0.0: + skipped.append( + {"id": object_id, "reason": "non_positive_effective_scale"} + ) + continue + + used.append( + { + "id": object_id, + "effective_scale": effective_scale, + "scale_factor_simready": scale_factor_simready, + "geo_scale": geo_scale, + "simready_bbox_size_m": simready_size_m.tolist(), + "simready_bbox_size_cm": (simready_size_m * 100.0).tolist(), + "current_scene_bbox_size_m": current_size_m.tolist(), + "current_scene_bbox_size_cm": (current_size_m * 100.0).tolist(), + "target_bbox_dims_cm": metric_scale.get("bbox_dims_cm"), + "confidence": metric_scale.get("confidence"), + } + ) + + if not used: + return { + "status": "fallback", + "method": "simready_reference_geo_ratio_mean_with_clamp", + "scale_factor": 1.0, + "raw_scale_factor": 1.0, + "was_clamped": False, + "clamp": {"min": request.min_scale, "max": request.max_scale}, + "object_count": len(request.objects), + "used_count": 0, + "skipped_count": len(skipped), + "used": [], + "skipped": skipped, + "unit_note": ( + "No valid metric scale was available; image clutter keeps the " + "SAM3D layout scale without an additional metric scale." + ), + } + + raw_scale_factor = float(np.mean([item["effective_scale"] for item in used])) + scale_factor = float( + np.clip(raw_scale_factor, request.min_scale, request.max_scale) + ) + return { + "status": "ok", + "method": "simready_reference_geo_ratio_mean_with_clamp", + "scale_factor": scale_factor, + "raw_scale_factor": raw_scale_factor, + "was_clamped": bool(scale_factor != raw_scale_factor), + "clamp": {"min": request.min_scale, "max": request.max_scale}, + "object_count": len(request.objects), + "used_count": len(used), + "skipped_count": len(skipped), + "used": used, + "skipped": skipped, + "unit_note": ( + "Global scale derived from scene-level VLM per-object scale_factor " + "divided by the geometric scale ratio between simready normalized " + "bbox and current aligned scene bbox (sorted, permutation-invariant). " + f"Aggregated via mean across objects, clamped to " + f"[{request.min_scale:.2f}, {request.max_scale:.2f}]." + ), + } diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/schemas.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/schemas.py new file mode 100644 index 000000000..dd2de3437 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/schemas.py @@ -0,0 +1,73 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +__all__ = [ + "EstimateMetricScalesRequest", + "EstimateMetricScalesResult", + "GlobalMetricScaleRequest", + "MetricScaleObjectInput", +] + + +@dataclass(frozen=True) +class MetricScaleObjectInput: + """Object input for metric-scale estimation.""" + + object_id: str + object_name: str + object_description: str + mesh_path: Path + + +@dataclass(frozen=True) +class EstimateMetricScalesRequest: + """Request to estimate metric scale for a set of normalized objects.""" + + objects: list[MetricScaleObjectInput] + messages: list[dict[str, Any]] + schema: dict[str, Any] + llm: Any + context: str + method: str + step_name: str = "metric_scale" + raw_output_path: Path | None = None + + +@dataclass(frozen=True) +class EstimateMetricScalesResult: + """Result of estimating metric scale for normalized objects.""" + + status: str + object_scales: list[dict[str, Any]] + object_payload: list[dict[str, Any]] + raw_model_output: dict[str, Any] | None = None + reason: str = "" + + +@dataclass(frozen=True) +class GlobalMetricScaleRequest: + """Request to aggregate per-object metric scales into one scene scale.""" + + objects: list[dict[str, Any]] + object_scenes: list[tuple[str, Any]] + min_scale: float = 0.10 + max_scale: float = 10.00 diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/__init__.py new file mode 100644 index 000000000..b61756bf0 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/__init__.py @@ -0,0 +1,37 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.optimization_manager.manager import ( + _center_xy_aabb_layout, + _footprint_layout_diagnostics, + _object_scenes_xy_aabb_manifest, + _settle_and_pack_object_footprints, + _xy_aabb_overlap, + _xy_union_area, + _xy_union_bounds, +) + +__all__ = [ + "_center_xy_aabb_layout", + "_footprint_layout_diagnostics", + "_object_scenes_xy_aabb_manifest", + "_settle_and_pack_object_footprints", + "_xy_aabb_overlap", + "_xy_union_area", + "_xy_union_bounds", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/manager.py new file mode 100644 index 000000000..d7ed13484 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/manager.py @@ -0,0 +1,633 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + + +from __future__ import annotations + +import tempfile +import traceback +from pathlib import Path +from typing import Any + +import numpy as np + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager import ( + SimulationManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager.schemas import ( + GravityDropRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.scene_geometry import ( + _aabb_bottom_to_xy_plane_transform, + _copy_scene_with_transform, + _matrix_from_json, + _scene_to_mesh, + _xy_aabb_center, + _xy_aabb_size, + _z_up_to_glb_y_up_transform, +) +from embodichain.gen_sim.prompt2scene.utils.io import ( + relative_path, +) + +__all__ = [ + "_center_xy_aabb_layout", + "_object_scenes_xy_aabb_manifest", + "_settle_and_pack_object_footprints", + "_xy_aabb_overlap", + "_xy_union_area", + "_xy_union_bounds", +] + +def _object_scenes_xy_aabb_manifest( + *, + object_scenes: list[tuple[str, Any]], + trimesh: Any, + unit_scale: float, + unit: str, +) -> dict[str, Any]: + if not object_scenes: + return { + "status": "empty", + "unit": unit, + "object_count": 0, + } + bounds = [ + np.asarray(_scene_to_mesh(scene, trimesh=trimesh).bounds, dtype=np.float64) + for _, scene in object_scenes + ] + union_bounds = np.vstack( + [ + np.vstack([item[0] for item in bounds]).min(axis=0), + np.vstack([item[1] for item in bounds]).max(axis=0), + ] + ) + min_xy = union_bounds[0, :2] * unit_scale + max_xy = union_bounds[1, :2] * unit_scale + size_xy = max_xy - min_xy + center_xy = 0.5 * (min_xy + max_xy) + return { + "status": "ok", + "unit": unit, + "object_count": len(object_scenes), + "min_xy": min_xy.tolist(), + "max_xy": max_xy.tolist(), + "center_xy": center_xy.tolist(), + "size_xy": size_xy.tolist(), + "area": float(size_xy[0] * size_xy[1]), + } + + + +def _settle_and_pack_object_footprints( + *, + object_scenes: list[tuple[str, Any]], + output_dir: Path, + output_root: Path, + trimesh: Any, +) -> dict[str, Any]: + sim = SimulationManager(headless=True, sim_device="cpu") + footprint_items: list[dict[str, Any]] = [] + settled_entries: list[dict[str, Any]] = [] + output_axis_transform = _z_up_to_glb_y_up_transform() + output_to_internal_transform = np.linalg.inv(output_axis_transform) + + with tempfile.TemporaryDirectory(prefix="p2s_footprint_drop_") as tmp_dir: + tmp_path = Path(tmp_dir) + for object_id, scene in object_scenes: + mesh = _scene_to_mesh(scene, trimesh=trimesh) + mesh_bounds = np.asarray(mesh.bounds, dtype=np.float64) + mesh_z_height = max(float(mesh_bounds[1][2] - mesh_bounds[0][2]), 0.0) + bottom_to_xy_plane_transform = _aabb_bottom_to_xy_plane_transform( + mesh_bounds + ) + normalized_scene = _copy_scene_with_transform( + scene, + bottom_to_xy_plane_transform, + ) + normalized_output_scene = _copy_scene_with_transform( + normalized_scene, + output_axis_transform, + ) + pre_gravity_path = tmp_path / f"{object_id}_pre_gravity.glb" + normalized_output_scene.export(pre_gravity_path) + gravity_initial_height = mesh_z_height * 0.1 + + gravity_status = "ok" + gravity_transform = np.eye(4, dtype=np.float64) + gravity_reason = "" + try: + gravity_result = sim.run_gravity_simulation( + GravityDropRequest( + glb_path=pre_gravity_path, + max_convex_hull_num=32, + initial_height=gravity_initial_height, + ) + ) + gravity_transform = _matrix_from_json( + gravity_result.final_pose, + name=f"{object_id}.gravity_final_pose", + ) + except Exception: + gravity_status = "failed" + gravity_reason = traceback.format_exc() + + settled_origin_scene = _copy_scene_with_transform( + normalized_scene, + gravity_transform, + ) + settled_mesh = _scene_to_mesh(settled_origin_scene, trimesh=trimesh) + settled_bounds = np.asarray(settled_mesh.bounds, dtype=np.float64) + settled_xy_center = _xy_aabb_center(settled_bounds) + settled_xy_size = _xy_aabb_size(settled_bounds) + settled_entries.append( + { + "id": object_id, + "scene": scene, + "bottom_to_xy_plane_transform": bottom_to_xy_plane_transform, + "mesh_z_height": mesh_z_height, + "gravity_initial_height": gravity_initial_height, + "gravity_transform": gravity_transform, + "settled_bounds": settled_bounds, + "settled_xy_center": settled_xy_center, + "settled_xy_size": settled_xy_size, + "gravity_status": gravity_status, + "gravity_reason": gravity_reason, + } + ) + + layout_result = _optimize_xy_aabb_footprint_layout( + object_ids=[str(entry["id"]) for entry in settled_entries], + xy_sizes={ + str(entry["id"]): np.asarray(entry["settled_xy_size"], dtype=np.float64) + for entry in settled_entries + }, + current_centers={ + str(entry["id"]): _xy_aabb_center( + _scene_to_mesh(entry["scene"], trimesh=trimesh).bounds + ) + for entry in settled_entries + }, + ) + target_centers = layout_result["centers"] + + packed_object_scenes: list[tuple[str, Any]] = [] + object_layout_transforms: dict[str, np.ndarray] = {} + for entry in settled_entries: + object_id = str(entry["id"]) + settled_bounds = np.asarray(entry["settled_bounds"], dtype=np.float64) + target_xy = target_centers[object_id] + placement_transform = np.eye(4, dtype=np.float64) + placement_transform[:3, 3] = [ + float(target_xy[0] - entry["settled_xy_center"][0]), + float(target_xy[1] - entry["settled_xy_center"][1]), + -float(settled_bounds[0][2]), + ] + object_transform = ( + placement_transform + @ entry["gravity_transform"] + @ entry["bottom_to_xy_plane_transform"] + ) + packed_scene = _copy_scene_with_transform(entry["scene"], object_transform) + packed_object_scenes.append((object_id, packed_scene)) + object_layout_transforms[object_id] = object_transform + + packed_bounds = np.asarray( + _scene_to_mesh(packed_scene, trimesh=trimesh).bounds, + dtype=np.float64, + ) + footprint_items.append( + { + "id": object_id, + "gravity_status": entry["gravity_status"], + "gravity_reason": entry["gravity_reason"], + "bottom_to_xy_plane_transform": entry[ + "bottom_to_xy_plane_transform" + ].tolist(), + "mesh_z_height": entry["mesh_z_height"], + "gravity_initial_height": entry["gravity_initial_height"], + "gravity_transform": entry["gravity_transform"].tolist(), + "placement_transform": placement_transform.tolist(), + "object_layout_transform": object_transform.tolist(), + "settled_xy_size": entry["settled_xy_size"].tolist(), + "target_xy_center": target_xy.tolist(), + "packed_bounds": packed_bounds.tolist(), + } + ) + + manifest = { + "status": "ok", + "method": "per_object_gravity_then_geometry_knn_2d_aabb_relaxation", + "output_dir": relative_path(str(output_dir), output_root), + "internal_up_axis": [0.0, 0.0, 1.0], + "gravity_glb_up_axis": [0.0, 1.0, 0.0], + "internal_to_gravity_glb_transform": output_axis_transform.tolist(), + "gravity_glb_to_internal_transform": output_to_internal_transform.tolist(), + "layout_optimization": layout_result["metadata"], + "items": footprint_items, + } + return { + "object_scenes": packed_object_scenes, + "object_layout_transforms": object_layout_transforms, + "manifest": manifest, + } + + + +def _optimize_xy_aabb_footprint_layout( + *, + object_ids: list[str], + xy_sizes: dict[str, np.ndarray], + current_centers: dict[str, np.ndarray], + padding_ratio: float = 0.08, +) -> dict[str, Any]: + if not object_ids: + return { + "centers": {}, + "metadata": { + "method": "geometry_knn_2d_aabb_relaxation", + "iterations": 0, + "confidence_score": 1.0, + }, + } + + max_extent = max( + float(max(xy_sizes[object_id][0], xy_sizes[object_id][1])) + for object_id in object_ids + ) + padding = max(max_extent * padding_ratio, 1e-3) + max_iterations = 300 + overlap_strength = 1.0 + neighbor_strength = 0.04 + compactness_strength = 0.01 + target_expansion_ratio = 1.2 + knn_k = min(3, max(len(object_ids) - 1, 0)) + centers = { + object_id: np.asarray( + current_centers.get(object_id, np.zeros(2, dtype=np.float64)), + dtype=np.float64, + ).copy() + for object_id in object_ids + } + centers = _center_xy_aabb_layout( + centers=centers, + xy_sizes=xy_sizes, + ) + initial_centers = { + object_id: center.copy() + for object_id, center in centers.items() + } + initial_union_bounds = _xy_union_bounds( + centers=initial_centers, + xy_sizes=xy_sizes, + ) + neighbor_edges = _knn_neighbor_edges( + centers=initial_centers, + k=knn_k, + ) + + iterations = 0 + for iteration in range(max_iterations): + iterations = iteration + 1 + max_delta = 0.0 + + for i, object_id in enumerate(object_ids): + for other_id in object_ids[i + 1 :]: + overlap = _xy_aabb_overlap( + center_a=centers[object_id], + size_a=xy_sizes[object_id], + center_b=centers[other_id], + size_b=xy_sizes[other_id], + padding=padding, + ) + if overlap is None: + continue + overlap_x, overlap_y = overlap + if overlap_x <= overlap_y: + axis = 0 + sign = ( + -1.0 + if centers[object_id][0] <= centers[other_id][0] + else 1.0 + ) + amount = overlap_x + else: + axis = 1 + sign = ( + -1.0 + if centers[object_id][1] <= centers[other_id][1] + else 1.0 + ) + amount = overlap_y + shift = 0.5 * (amount + 1e-6) * overlap_strength + centers[object_id][axis] += sign * shift + centers[other_id][axis] -= sign * shift + max_delta = max(max_delta, shift) + + for edge in neighbor_edges: + object_id = edge["object"] + neighbor_id = edge["neighbor"] + initial_delta = np.asarray(edge["initial_delta"], dtype=np.float64) + error = (centers[object_id] - centers[neighbor_id]) - initial_delta + correction = 0.5 * neighbor_strength * error + centers[object_id] -= correction + centers[neighbor_id] += correction + max_delta = max(max_delta, float(np.linalg.norm(correction))) + + max_delta = max( + max_delta, + _apply_compactness_pull( + centers=centers, + xy_sizes=xy_sizes, + initial_union_bounds=initial_union_bounds, + target_expansion_ratio=target_expansion_ratio, + strength=compactness_strength, + ), + ) + + centers = _center_xy_aabb_layout( + centers=centers, + xy_sizes=xy_sizes, + ) + if iteration >= 20 and max_delta < 1e-5: + break + + diagnostics = _footprint_layout_diagnostics( + object_ids=object_ids, + centers=centers, + initial_centers=initial_centers, + xy_sizes=xy_sizes, + padding=padding, + initial_union_bounds=initial_union_bounds, + ) + metadata = { + "method": "geometry_knn_2d_aabb_relaxation", + "relation_usage": "disabled", + "iterations": iterations, + "padding": padding, + "padding_ratio": padding_ratio, + "max_iterations": max_iterations, + "overlap_strength": overlap_strength, + "neighbor_strength": neighbor_strength, + "compactness_strength": compactness_strength, + "target_expansion_ratio": target_expansion_ratio, + "knn_k": knn_k, + "neighbor_edges": neighbor_edges, + "final_centers": { + object_id: centers[object_id].tolist() + for object_id in object_ids + }, + **diagnostics, + } + return {"centers": centers, "metadata": metadata} + + + +def _knn_neighbor_edges( + *, + centers: dict[str, np.ndarray], + k: int, +) -> list[dict[str, Any]]: + if k <= 0 or len(centers) < 2: + return [] + object_ids = sorted(centers) + edges: list[dict[str, Any]] = [] + seen: set[tuple[str, str]] = set() + for object_id in object_ids: + distances = [] + for other_id in object_ids: + if other_id == object_id: + continue + distance = float(np.linalg.norm(centers[object_id] - centers[other_id])) + distances.append((distance, other_id)) + for _, neighbor_id in sorted(distances)[:k]: + edge_key = tuple(sorted((object_id, neighbor_id))) + if edge_key in seen: + continue + seen.add(edge_key) + edges.append( + { + "object": object_id, + "neighbor": neighbor_id, + "initial_delta": ( + centers[object_id] - centers[neighbor_id] + ).tolist(), + } + ) + return edges + + + +def _apply_compactness_pull( + *, + centers: dict[str, np.ndarray], + xy_sizes: dict[str, np.ndarray], + initial_union_bounds: np.ndarray, + target_expansion_ratio: float, + strength: float, +) -> float: + current_bounds = _xy_union_bounds(centers=centers, xy_sizes=xy_sizes) + expansion_ratio = _xy_union_area(current_bounds) / max( + _xy_union_area(initial_union_bounds), + 1.0e-12, + ) + if expansion_ratio <= target_expansion_ratio: + return 0.0 + excess = min(expansion_ratio / target_expansion_ratio - 1.0, 1.0) + union_center = 0.5 * (current_bounds[0] + current_bounds[1]) + factor = strength * excess + max_delta = 0.0 + for object_id, center in centers.items(): + delta = factor * (union_center - center) + centers[object_id] = center + delta + max_delta = max(max_delta, float(np.linalg.norm(delta))) + return max_delta + + + +def _footprint_layout_diagnostics( + *, + object_ids: list[str], + centers: dict[str, np.ndarray], + initial_centers: dict[str, np.ndarray], + xy_sizes: dict[str, np.ndarray], + padding: float, + initial_union_bounds: np.ndarray, +) -> dict[str, Any]: + remaining_overlaps = _remaining_xy_overlaps( + object_ids=object_ids, + centers=centers, + xy_sizes=xy_sizes, + padding=padding, + ) + displacements = [ + float(np.linalg.norm(centers[object_id] - initial_centers[object_id])) + for object_id in object_ids + ] + current_union_bounds = _xy_union_bounds(centers=centers, xy_sizes=xy_sizes) + expansion_ratio = _xy_union_area(current_union_bounds) / max( + _xy_union_area(initial_union_bounds), + 1.0e-12, + ) + average_displacement = float(np.mean(displacements)) if displacements else 0.0 + max_displacement = float(np.max(displacements)) if displacements else 0.0 + confidence_score = _footprint_confidence_score( + remaining_overlap_count=len(remaining_overlaps), + average_displacement=average_displacement, + max_extent=max( + float(max(xy_sizes[object_id][0], xy_sizes[object_id][1])) + for object_id in object_ids + ) + if object_ids + else 1.0, + expansion_ratio=expansion_ratio, + ) + return { + "remaining_overlaps": remaining_overlaps, + "average_displacement": average_displacement, + "max_displacement": max_displacement, + "union_aabb_expansion_ratio": expansion_ratio, + "confidence_score": confidence_score, + } + + + +def _remaining_xy_overlaps( + *, + object_ids: list[str], + centers: dict[str, np.ndarray], + xy_sizes: dict[str, np.ndarray], + padding: float, +) -> list[dict[str, Any]]: + overlaps: list[dict[str, Any]] = [] + for index, object_id in enumerate(object_ids): + for other_id in object_ids[index + 1 :]: + overlap = _xy_aabb_overlap( + center_a=centers[object_id], + size_a=xy_sizes[object_id], + center_b=centers[other_id], + size_b=xy_sizes[other_id], + padding=padding, + ) + if overlap is None: + continue + overlaps.append( + { + "object": object_id, + "other": other_id, + "overlap_x": overlap[0], + "overlap_y": overlap[1], + } + ) + return overlaps + + + +def _footprint_confidence_score( + *, + remaining_overlap_count: int, + average_displacement: float, + max_extent: float, + expansion_ratio: float, +) -> float: + displacement_scale = max(max_extent, 1.0e-6) + overlap_penalty = min(0.35 * remaining_overlap_count, 0.7) + displacement_penalty = min(0.1 * average_displacement / displacement_scale, 0.2) + expansion_penalty = min(max(expansion_ratio - 1.2, 0.0) * 0.25, 0.2) + return float( + np.clip( + 1.0 + - overlap_penalty + - displacement_penalty + - expansion_penalty, + 0.0, + 1.0, + ) + ) + + + +def _center_xy_aabb_layout( + *, + centers: dict[str, np.ndarray], + xy_sizes: dict[str, np.ndarray], +) -> dict[str, np.ndarray]: + if not centers: + return centers + bounds_min = [] + bounds_max = [] + for object_id, center in centers.items(): + half_size = 0.5 * np.asarray(xy_sizes[object_id], dtype=np.float64) + bounds_min.append(center - half_size) + bounds_max.append(center + half_size) + clutter_center = 0.5 * ( + np.vstack(bounds_min).min(axis=0) + + np.vstack(bounds_max).max(axis=0) + ) + return { + object_id: np.asarray(center, dtype=np.float64) - clutter_center + for object_id, center in centers.items() + } + + + +def _xy_union_bounds( + *, + centers: dict[str, np.ndarray], + xy_sizes: dict[str, np.ndarray], +) -> np.ndarray: + if not centers: + return np.zeros((2, 2), dtype=np.float64) + bounds_min = [] + bounds_max = [] + for object_id, center in centers.items(): + half_size = 0.5 * np.asarray(xy_sizes[object_id], dtype=np.float64) + bounds_min.append(np.asarray(center, dtype=np.float64) - half_size) + bounds_max.append(np.asarray(center, dtype=np.float64) + half_size) + return np.vstack( + [ + np.vstack(bounds_min).min(axis=0), + np.vstack(bounds_max).max(axis=0), + ] + ) + + + +def _xy_union_area(bounds: np.ndarray) -> float: + bounds = np.asarray(bounds, dtype=np.float64) + size = np.maximum(bounds[1] - bounds[0], 1.0e-9) + return float(size[0] * size[1]) + + + +def _xy_aabb_overlap( + *, + center_a: np.ndarray, + size_a: np.ndarray, + center_b: np.ndarray, + size_b: np.ndarray, + padding: float, +) -> tuple[float, float] | None: + half_a = 0.5 * np.asarray(size_a, dtype=np.float64) + half_b = 0.5 * np.asarray(size_b, dtype=np.float64) + delta = np.abs( + np.asarray(center_b, dtype=np.float64) + - np.asarray(center_a, dtype=np.float64) + ) + overlap = half_a + half_b + padding - delta + if float(overlap[0]) <= 0.0 or float(overlap[1]) <= 0.0: + return None + return float(overlap[0]), float(overlap[1]) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/__init__.py new file mode 100644 index 000000000..12ebfd690 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/__init__.py @@ -0,0 +1,35 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simready_manager.manager import ( + SimreadyManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simready_manager.schemas import ( + MakeAssetSimreadyRequest, + MakeAssetSimreadyResult, + MakeTableSimreadyRequest, + MakeTableSimreadyResult, +) + +__all__ = [ + "MakeAssetSimreadyRequest", + "MakeAssetSimreadyResult", + "MakeTableSimreadyRequest", + "MakeTableSimreadyResult", + "SimreadyManager", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/manager.py new file mode 100644 index 000000000..6f92e1f84 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/manager.py @@ -0,0 +1,396 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +import numpy as np + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.manager import ( + DEFAULT_INPUT_UP_AXIS, + DEFAULT_UP_AXIS, + GeometryManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.schemas import ( + AlignToAxisRequest, + CenterMeshRequest, + ConvertUpAxisRequest, + DetectTabletopRequest, + ExportMeshRequest, + LoadMeshRequest, + NormalizeRequest, + PlaceAbovePlaneRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.matplotlib_manager.manager import ( + MatplotlibManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.matplotlib_manager.schemas import ( + RenderSupportRegionRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simready_manager.schemas import ( + MakeAssetSimreadyRequest, + MakeAssetSimreadyResult, + MakeTableSimreadyRequest, + MakeTableSimreadyResult, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager.manager import ( + SimulationManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager.schemas import ( + GravityDropRequest, +) + + +class SimreadyManager: + """Prepare generated GLB assets for simulation placement.""" + + def __init__( + self, + *, + geometry_manager: GeometryManager | None = None, + simulation_manager: SimulationManager | None = None, + matplotlib_manager: MatplotlibManager | None = None, + ) -> None: + self.geometry_manager = geometry_manager or GeometryManager() + self.simulation_manager = simulation_manager or SimulationManager() + self.matplotlib_manager = matplotlib_manager or MatplotlibManager() + + def make_asset_simready( + self, + request: MakeAssetSimreadyRequest, + ) -> MakeAssetSimreadyResult: + input_path = request.input_path.expanduser().resolve() + output_path = request.output_path.expanduser().resolve() + if output_path.suffix.lower() != ".glb": + raise ValueError("Sim-ready asset output_path must be a .glb file.") + output_path.parent.mkdir(parents=True, exist_ok=True) + + input_up_axis = _request_axis(request.input_up_axis, DEFAULT_INPUT_UP_AXIS) + raw_to_simready = np.eye(4, dtype=np.float64) + geom = self.geometry_manager + sim = self.simulation_manager + + mesh = geom.load_mesh(LoadMeshRequest(mesh_path=input_path)).mesh + + transform = _axis_conversion_transform(input_up_axis, DEFAULT_UP_AXIS) + raw_to_simready = transform @ raw_to_simready + mesh = geom.convert_up_axis( + ConvertUpAxisRequest( + mesh=mesh, + input_up_axis=input_up_axis, + output_up_axis=DEFAULT_UP_AXIS, + ) + ).mesh + + center_result = geom.center_by_bbox(CenterMeshRequest(mesh=mesh)) + mesh = center_result.mesh + transform = _translation_transform(-np.asarray(center_result.bbox_center)) + raw_to_simready = transform @ raw_to_simready + + transform = _place_above_plane_transform(mesh, request.ground_clearance) + raw_to_simready = transform @ raw_to_simready + mesh = geom.place_above_plane( + PlaceAbovePlaneRequest(mesh=mesh, clearance=request.ground_clearance) + ).mesh + + pre_gravity_mesh = geom.convert_up_axis( + ConvertUpAxisRequest( + mesh=mesh, + input_up_axis=DEFAULT_UP_AXIS, + output_up_axis=DEFAULT_INPUT_UP_AXIS, + ) + ).mesh + pre_gravity_path = output_path.with_name(f".{output_path.stem}_pre_gravity.glb") + geom.export_mesh( + ExportMeshRequest(mesh=pre_gravity_mesh, output_path=pre_gravity_path) + ) + try: + gravity_result = sim.run_gravity_simulation( + GravityDropRequest(glb_path=pre_gravity_path, max_convex_hull_num=32) + ) + + gravity_transform = _as_transform(gravity_result.final_pose) + settled_mesh = mesh.copy() + settled_mesh.apply_transform(gravity_transform) + raw_to_simready = gravity_transform @ raw_to_simready + transform = _center_aabb_bottom_xy_at_origin_transform(settled_mesh) + settled_mesh.apply_transform(transform) + raw_to_simready = transform @ raw_to_simready + + transform = _center_aabb_bottom_xy_at_origin_transform(settled_mesh) + raw_to_simready = transform @ raw_to_simready + final_mesh = _center_aabb_bottom_xy_at_origin(settled_mesh) + + normalize_result = geom.normalize(NormalizeRequest(mesh=final_mesh)) + final_mesh = normalize_result.mesh + transform = _scale_transform(normalize_result.scale_factor) + raw_to_simready = transform @ raw_to_simready + + transform = _place_above_plane_transform(final_mesh, request.ground_clearance) + raw_to_simready = transform @ raw_to_simready + final_mesh = geom.place_above_plane( + PlaceAbovePlaneRequest( + mesh=final_mesh, + clearance=request.ground_clearance, + ) + ).mesh + + transform = _axis_conversion_transform(DEFAULT_UP_AXIS, DEFAULT_INPUT_UP_AXIS) + raw_to_simready = transform @ raw_to_simready + final_mesh = geom.convert_up_axis( + ConvertUpAxisRequest( + mesh=final_mesh, + input_up_axis=DEFAULT_UP_AXIS, + output_up_axis=DEFAULT_INPUT_UP_AXIS, + ) + ).mesh + + geom.export_mesh(ExportMeshRequest(mesh=final_mesh, output_path=output_path)) + finally: + pre_gravity_path.unlink(missing_ok=True) + + return MakeAssetSimreadyResult( + output_path=output_path, + transform_matrix=raw_to_simready.tolist(), + ) + + def make_table_simready( + self, + request: MakeTableSimreadyRequest, + ) -> MakeTableSimreadyResult: + input_path = request.input_path.expanduser().resolve() + output_path = request.output_path.expanduser().resolve() + if output_path.suffix.lower() != ".glb": + raise ValueError("Sim-ready table output_path must be a .glb file.") + output_path.parent.mkdir(parents=True, exist_ok=True) + + input_up_axis = _request_axis(request.input_up_axis, DEFAULT_INPUT_UP_AXIS) + up_axis = _request_axis(request.up_axis, DEFAULT_UP_AXIS) + raw_to_simready = np.eye(4, dtype=np.float64) + geom = self.geometry_manager + sim = self.simulation_manager + mpl = self.matplotlib_manager + + mesh = geom.load_mesh(LoadMeshRequest(mesh_path=input_path)).mesh + + transform = _axis_conversion_transform(input_up_axis, DEFAULT_UP_AXIS) + raw_to_simready = transform @ raw_to_simready + mesh = geom.convert_up_axis( + ConvertUpAxisRequest( + mesh=mesh, + input_up_axis=input_up_axis, + output_up_axis=DEFAULT_UP_AXIS, + ) + ).mesh + + center_result = geom.center_by_bbox(CenterMeshRequest(mesh=mesh)) + mesh = center_result.mesh + transform = _translation_transform(-np.asarray(center_result.bbox_center)) + raw_to_simready = transform @ raw_to_simready + + detect_result = geom.detect_tabletop(DetectTabletopRequest(mesh=mesh)) + + transform = _axis_conversion_transform(detect_result.oriented_normal, up_axis) + raw_to_simready = transform @ raw_to_simready + mesh = geom.align_to_axis( + AlignToAxisRequest( + mesh=mesh, + source_axis=detect_result.oriented_normal, + target_axis=up_axis, + ) + ).mesh + + transform = _place_above_plane_transform(mesh, request.ground_clearance) + raw_to_simready = transform @ raw_to_simready + mesh = geom.place_above_plane( + PlaceAbovePlaneRequest(mesh=mesh, clearance=request.ground_clearance) + ).mesh + + pre_gravity_mesh = geom.convert_up_axis( + ConvertUpAxisRequest( + mesh=mesh, + input_up_axis=DEFAULT_UP_AXIS, + output_up_axis=DEFAULT_INPUT_UP_AXIS, + ) + ).mesh + pre_gravity_path = output_path.with_name(f".{output_path.stem}_pre_gravity.glb") + geom.export_mesh( + ExportMeshRequest(mesh=pre_gravity_mesh, output_path=pre_gravity_path) + ) + try: + gravity_result = sim.run_gravity_simulation( + GravityDropRequest(glb_path=pre_gravity_path, max_convex_hull_num=16) + ) + + gravity_transform = _as_transform(gravity_result.final_pose) + settled_mesh = mesh.copy() + settled_mesh.apply_transform(gravity_transform) + raw_to_simready = gravity_transform @ raw_to_simready + transform = _center_aabb_bottom_xy_at_origin_transform(settled_mesh) + settled_mesh.apply_transform(transform) + raw_to_simready = transform @ raw_to_simready + + settled_detect = geom.detect_tabletop( + DetectTabletopRequest(mesh=settled_mesh) + ) + + mpl.render_selected_support_region( + RenderSupportRegionRequest( + mesh=settled_mesh, + face_indices=settled_detect.selected.face_indices, + output_path=output_path.with_name( + f"{output_path.stem}_support_region.png" + ), + ) + ) + + transform = _center_aabb_bottom_xy_at_origin_transform(settled_mesh) + raw_to_simready = transform @ raw_to_simready + final_mesh = _center_aabb_bottom_xy_at_origin(settled_mesh) + + normalize_result = geom.normalize(NormalizeRequest(mesh=final_mesh)) + final_mesh = normalize_result.mesh + transform = _scale_transform(normalize_result.scale_factor) + raw_to_simready = transform @ raw_to_simready + + transform = _place_above_plane_transform(final_mesh, request.ground_clearance) + raw_to_simready = transform @ raw_to_simready + final_mesh = geom.place_above_plane( + PlaceAbovePlaneRequest( + mesh=final_mesh, + clearance=request.ground_clearance, + ) + ).mesh + + transform = _axis_conversion_transform(DEFAULT_UP_AXIS, DEFAULT_INPUT_UP_AXIS) + raw_to_simready = transform @ raw_to_simready + final_mesh = geom.convert_up_axis( + ConvertUpAxisRequest( + mesh=final_mesh, + input_up_axis=DEFAULT_UP_AXIS, + output_up_axis=DEFAULT_INPUT_UP_AXIS, + ) + ).mesh + + geom.export_mesh(ExportMeshRequest(mesh=final_mesh, output_path=output_path)) + finally: + pre_gravity_path.unlink(missing_ok=True) + + return MakeTableSimreadyResult( + output_path=output_path, + transform_matrix=raw_to_simready.tolist(), + ) + + +def _request_axis(value: list[float] | None, default: tuple[float, float, float]) -> list[float]: + if value is not None: + return list(value) + return list(default) + + +def _center_aabb_bottom_xy_at_origin(mesh: Any) -> Any: + bounds = mesh.bounds + bottom_center_x = (float(bounds[0][0]) + float(bounds[1][0])) * 0.5 + bottom_center_y = (float(bounds[0][1]) + float(bounds[1][1])) * 0.5 + centered = mesh.copy() + centered.apply_translation([-bottom_center_x, -bottom_center_y, 0.0]) + return centered + + +def _axis_conversion_transform(source_axis: list[float], target_axis: list[float]) -> np.ndarray: + source = _normalize(np.asarray(source_axis, dtype=np.float64)) + target = _normalize(np.asarray(target_axis, dtype=np.float64)) + return _rotation_between_vectors(source, target) + + +def _place_above_plane_transform(mesh: Any, clearance: float) -> np.ndarray: + min_z = float(mesh.bounds[0][2]) + return _translation_transform(np.array([0.0, 0.0, clearance - min_z])) + + +def _center_aabb_bottom_xy_at_origin_transform(mesh: Any) -> np.ndarray: + bounds = mesh.bounds + bottom_center_x = (float(bounds[0][0]) + float(bounds[1][0])) * 0.5 + bottom_center_y = (float(bounds[0][1]) + float(bounds[1][1])) * 0.5 + return _translation_transform(np.array([-bottom_center_x, -bottom_center_y, 0.0])) + + +def _translation_transform(translation: np.ndarray) -> np.ndarray: + transform = np.eye(4, dtype=np.float64) + transform[:3, 3] = translation + return transform + + +def _scale_transform(scale: float) -> np.ndarray: + transform = np.eye(4, dtype=np.float64) + transform[:3, :3] *= float(scale) + return transform + + +def _as_transform(value: Any) -> np.ndarray: + transform = np.asarray(value, dtype=np.float64) + if transform.shape != (4, 4): + raise ValueError("Expected a 4x4 transform matrix.") + return transform + + +def _rotation_between_vectors(source: np.ndarray, target: np.ndarray) -> np.ndarray: + source = _normalize(source) + target = _normalize(target) + dot = float(np.clip(np.dot(source, target), -1.0, 1.0)) + transform = np.eye(4, dtype=np.float64) + if dot > 1.0 - 1e-8: + return transform + if dot < -1.0 + 1e-8: + axis = _orthogonal_axis(source) + rotation = _axis_angle_rotation(axis, np.pi) + else: + axis = _normalize(np.cross(source, target)) + angle = float(np.arccos(dot)) + rotation = _axis_angle_rotation(axis, angle) + transform[:3, :3] = rotation + return transform + + +def _axis_angle_rotation(axis: np.ndarray, angle: float) -> np.ndarray: + axis = _normalize(axis) + x, y, z = axis + c = float(np.cos(angle)) + s = float(np.sin(angle)) + one_c = 1.0 - c + return np.array( + [ + [c + x * x * one_c, x * y * one_c - z * s, x * z * one_c + y * s], + [y * x * one_c + z * s, c + y * y * one_c, y * z * one_c - x * s], + [z * x * one_c - y * s, z * y * one_c + x * s, c + z * z * one_c], + ], + dtype=np.float64, + ) + + +def _orthogonal_axis(vector: np.ndarray) -> np.ndarray: + axis = np.array([1.0, 0.0, 0.0], dtype=np.float64) + if abs(float(np.dot(vector, axis))) > 0.9: + axis = np.array([0.0, 1.0, 0.0], dtype=np.float64) + return _normalize(np.cross(vector, axis)) + + +def _normalize(vector: np.ndarray) -> np.ndarray: + norm = float(np.linalg.norm(vector)) + if norm == 0.0: + return vector + return vector / norm diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/schemas.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/schemas.py new file mode 100644 index 000000000..86ae22b0a --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/schemas.py @@ -0,0 +1,58 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + + +@dataclass(frozen=True) +class MakeAssetSimreadyRequest: + """Request to prepare a general asset GLB for simulation placement.""" + + input_path: Path + output_path: Path + input_up_axis: list[float] | None = None + up_axis: list[float] | None = None + ground_clearance: float = 0.01 + + +@dataclass(frozen=True) +class MakeAssetSimreadyResult: + """Result of making an asset simulation-ready.""" + + output_path: Path + transform_matrix: list[list[float]] + + +@dataclass(frozen=True) +class MakeTableSimreadyRequest: + """Request to prepare a generated table GLB for simulation placement.""" + + input_path: Path + output_path: Path + input_up_axis: list[float] | None = None + up_axis: list[float] | None = None + ground_clearance: float = 0.01 + + +@dataclass(frozen=True) +class MakeTableSimreadyResult: + """Result of making a table simulation-ready.""" + + output_path: Path + transform_matrix: list[list[float]] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/simulation_manager/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simulation_manager/__init__.py new file mode 100644 index 000000000..9441c6b80 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simulation_manager/__init__.py @@ -0,0 +1,31 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager.manager import ( + SimulationManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager.schemas import ( + GravityDropRequest, + GravityDropResult, +) + +__all__ = [ + "GravityDropRequest", + "GravityDropResult", + "SimulationManager", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/simulation_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simulation_manager/manager.py new file mode 100644 index 000000000..4a0721103 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simulation_manager/manager.py @@ -0,0 +1,124 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Simulation manager for gravity-based asset placement.""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import torch +import trimesh + +from embodichain.lab.sim.cfg import RigidObjectCfg +from embodichain.lab.sim.shapes import MeshCfg +from embodichain.lab.sim.sim_manager import ( + SimulationManager as _EmbodiSimManager, + SimulationManagerCfg, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager.schemas import ( + GravityDropRequest, + GravityDropResult, +) + +__all__ = ["SimulationManager"] + + +class SimulationManager: + """Manager for gravity-based asset placement. + + Wraps an EmbodiChain simulation instance with typed request/response + methods, following the same pattern as service clients. + """ + + def __init__( + self, + *, + headless: bool = True, + physics_dt: float = 0.01, + sim_device: str = "cpu", + ) -> None: + """Initialize the simulation manager. + + Args: + headless: Whether to run without a GUI. + physics_dt: Physics timestep in seconds. + sim_device: Device to run the simulation on. + """ + self._headless = headless + self._physics_dt = physics_dt + self._sim_device = sim_device + + def run_gravity_simulation( + self, request: GravityDropRequest + ) -> GravityDropResult: + """Drop one GLB under gravity and return its final pose.""" + glb_path = request.glb_path.expanduser().resolve() + if not glb_path.is_file(): + raise FileNotFoundError(f"GLB file not found: {glb_path}") + + initial_height = ( + float(request.initial_height) + if request.initial_height is not None + else self._compute_adaptive_drop_height(glb_path) + ) + sim = _EmbodiSimManager( + SimulationManagerCfg( + headless=self._headless, + physics_dt=self._physics_dt, + sim_device=self._sim_device, + ) + ) + obj = sim.add_rigid_object( + RigidObjectCfg( + uid="dropped_asset", + shape=MeshCfg(fpath=str(glb_path)), + init_pos=(0.0, 0.0, initial_height), + init_rot=(0.0, 0.0, 0.0), + body_type="dynamic", + max_convex_hull_num=request.max_convex_hull_num, + ) + ) + sim.update(step=300) + + final_pose = obj.get_local_pose(to_matrix=True)[0].detach().cpu() + sim._deferred_destroy() + return GravityDropResult( + final_pose=np.asarray(final_pose.numpy(), dtype=float), + ) + + def _compute_adaptive_drop_height( + self, + glb_path: Path, + *, + min_clearance: float = 0.2, + height_scale: float = 1.25, + ) -> float: + """Compute an initial drop height from a GLB bounding box.""" + if min_clearance < 0.0: + raise ValueError("min_clearance must be non-negative.") + if height_scale <= 0.0: + raise ValueError("height_scale must be positive.") + + glb_path = glb_path.expanduser().resolve() + loaded = trimesh.load(glb_path, force=None) + if isinstance(loaded, trimesh.Scene): + bounds = loaded.bounds + else: + bounds = loaded.bounds + height = float(bounds[1][2] - bounds[0][2]) + return max(height * height_scale, height + min_clearance) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/simulation_manager/schemas.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simulation_manager/schemas.py new file mode 100644 index 000000000..c9df4a526 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simulation_manager/schemas.py @@ -0,0 +1,42 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +__all__ = [ + "GravityDropRequest", + "GravityDropResult", +] + + +@dataclass(frozen=True) +class GravityDropRequest: + """Request to drop a GLB asset under gravity simulation.""" + + glb_path: Path + max_convex_hull_num: int = 32 + initial_height: float | None = None + + +@dataclass(frozen=True) +class GravityDropResult: + """Result of dropping a GLB asset under gravity.""" + + final_pose: Any diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/__init__.py new file mode 100644 index 000000000..0819a0d37 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/__init__.py @@ -0,0 +1,23 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.table_clutter_fit_manager.manager import ( + fit_table_to_clutter, +) + +__all__ = ["fit_table_to_clutter"] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py new file mode 100644 index 000000000..987e14878 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py @@ -0,0 +1,298 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + + +from __future__ import annotations + +import json +import tempfile +from pathlib import Path +from typing import Any + +import numpy as np + +from embodichain.gen_sim.prompt2scene.utils.io import relative_path +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.scene_geometry import ( + _copy_scene_with_transform, + _scene_to_mesh, + _z_up_to_glb_y_up_transform, + _detect_table_fit_support_quad, + _load_table_fit_scene_internal_z, + _table_fit_bounds_xy_manifest, + _table_fit_safe_positive_ratio, + _table_fit_scene_union_bounds, + _table_fit_uniform_xy_scale_transform, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager import ( + SimulationManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager.schemas import ( + GravityDropRequest, +) + +__all__ = ["fit_table_to_clutter"] + + +def _resolve_generated_path(value: Any, output_root: Path) -> Path: + if not value: + return Path() + path = Path(str(value)).expanduser() + if path.is_absolute(): + return path.resolve() + return (output_root.expanduser().resolve() / path).resolve() + + +def _gravity_settle_table_fit_internal_z_scene( + scene: Any, + *, + z_to_y: np.ndarray, + sim_device: str, +) -> Any: + sim = SimulationManager(headless=True, sim_device=sim_device) + with tempfile.TemporaryDirectory(prefix="p2s_table_fit_gravity_") as tmp: + tmp_path = Path(tmp) + pre_gravity = tmp_path / "table_pre_gravity.glb" + _copy_scene_with_transform(scene, z_to_y).export(pre_gravity) + result = sim.run_gravity_simulation( + GravityDropRequest( + glb_path=pre_gravity, + max_convex_hull_num=16, + initial_height=0.05, + ) + ) + settled = scene.copy() + settled.apply_transform(np.asarray(result.final_pose, dtype=np.float64)) + return settled + + +def _write_table_fit_json(path: Path, data: dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + json.dumps(data, ensure_ascii=False, indent=2) + "\n", + encoding="utf-8", + ) + + +def fit_table_to_clutter( + *, + table_result: dict[str, Any], + clutter_result: dict[str, Any], + output_root: Path, + output_dir: Path, + margin_cm: float = 10.0, + support_occupancy_ratio: float = 0.80, + gravity_settle_table: bool = True, + sim_device: str = "cpu", +) -> dict[str, Any]: + """Fit a table mesh to an already laid-out clutter result.""" + try: + import trimesh + except ImportError as exc: + raise RuntimeError("Table fitting requires trimesh.") from exc + + output_root = output_root.expanduser().resolve() + output_dir = output_dir.expanduser().resolve() + output_dir.mkdir(parents=True, exist_ok=True) + + # Resolve the table geometry. + table_simready_path = _resolve_generated_path( + table_result.get("simready_geometry_path") or table_result.get("mesh_path"), + output_root, + ) + if not table_simready_path.is_file(): + raise FileNotFoundError(f"Table simready GLB not found: {table_simready_path}") + + # Resolve the clutter object geometries. + settled_objects = [ + item + for item in clutter_result.get("objects", []) + if isinstance(item, dict) and item.get("status") == "ok" + ] + if not settled_objects: + raise ValueError("No successfully settled objects for table fitting.") + + object_glb_paths: list[tuple[str, Path]] = [] + for item in settled_objects: + glb_path = _resolve_generated_path( + item.get("laid_out_glb_path") or item.get("settled_glb_path"), + output_root, + ) + if glb_path.is_file(): + object_glb_paths.append((str(item["id"]), glb_path)) + + if not object_glb_paths: + raise ValueError("No valid settled object GLBs for table fitting.") + + z_to_y = _z_up_to_glb_y_up_transform() + y_to_z = np.linalg.inv(z_to_y) + + # Load the table and detect its support surface. + table_scene = _load_table_fit_scene_internal_z( + table_simready_path, + trimesh=trimesh, + y_to_z=y_to_z, + ) + table_mesh = _scene_to_mesh(table_scene, trimesh=trimesh) + clutter_aabb = clutter_result.get("clutter_2d_aabb_cm") or {} + clutter_size = clutter_aabb.get("size_xy", [1.0, 1.0]) + target_aspect = float(clutter_size[0]) / max(float(clutter_size[1]), 1.0e-6) + initial_support = _detect_table_fit_support_quad( + table_mesh, + target_aspect=target_aspect, + ) + + # Load the clutter scenes. + clutter_scenes = [ + (oid, _load_table_fit_scene_internal_z(path, trimesh=trimesh, y_to_z=y_to_z)) + for oid, path in object_glb_paths + ] + clutter_bounds = _table_fit_scene_union_bounds( + [scene for _, scene in clutter_scenes], + trimesh=trimesh, + ) + + # Compute the required table size and uniform scale. + clutter_size_cm = (clutter_bounds[1, :2] - clutter_bounds[0, :2]) * 100.0 + occupancy = float(np.clip(support_occupancy_ratio, 0.1, 1.0)) + required_size_cm = clutter_size_cm / occupancy + 2.0 * float(margin_cm) + support_size_cm = np.asarray(initial_support["size_xy"], dtype=np.float64) * 100.0 + scale_x = _table_fit_safe_positive_ratio(required_size_cm[0], support_size_cm[0]) + scale_y = _table_fit_safe_positive_ratio(required_size_cm[1], support_size_cm[1]) + uniform_scale = max(scale_x, scale_y) + table_scale_transform = _table_fit_uniform_xy_scale_transform( + center_xy=np.asarray(initial_support["center_xy"], dtype=np.float64), + scale=uniform_scale, + ) + table_scene.apply_transform(table_scale_transform) + + # Settle the table under gravity. + if gravity_settle_table: + table_scene = _gravity_settle_table_fit_internal_z_scene( + table_scene, + z_to_y=z_to_y, + sim_device=sim_device, + ) + + # Reposition the table at the origin. + final_table_mesh = _scene_to_mesh(table_scene, trimesh=trimesh) + final_support = _detect_table_fit_support_quad( + final_table_mesh, + target_aspect=float(required_size_cm[0] / max(required_size_cm[1], 1.0e-6)), + ) + support_center = np.asarray(final_support["center"], dtype=np.float64) + table_bounds = np.asarray(final_table_mesh.bounds, dtype=np.float64) + table_bottom_z = float(table_bounds[0, 2]) + + table_shift = np.eye(4, dtype=np.float64) + table_shift[:3, 3] = [-support_center[0], -support_center[1], -table_bottom_z] + table_scene.apply_transform(table_shift) + support_z_after = float((support_center + table_shift[:3, 3])[2]) + + # Measure the table surface height. + # Use the highest point of the table mesh (after scaling + gravity + shift) + # rather than the support-plane mean Z, so that thin objects sit above the + # actual geometry even when the tabletop has slight unevenness. + _table_mesh_after_shift = _scene_to_mesh(table_scene, trimesh=trimesh) + _table_max_z = float( + np.asarray(_table_mesh_after_shift.bounds, dtype=np.float64)[1, 2] + ) + _surface_z_margin = 0.01 # 1 cm above the highest table point + + # Place the objects on the table. + placed_objects: list[dict[str, Any]] = [] + shifted_clutter: list[tuple[str, Any]] = [] + clutter_after = _table_fit_scene_union_bounds( + [scene for _, scene in clutter_scenes], + trimesh=trimesh, + ) + clutter_center_xy = 0.5 * (clutter_after[0, :2] + clutter_after[1, :2]) + for oid, scene in clutter_scenes: + obj_mesh = _scene_to_mesh(scene, trimesh=trimesh) + obj_bounds = np.asarray(obj_mesh.bounds, dtype=np.float64) + obj_bottom_z = float(obj_bounds[0, 2]) + obj_shift = np.eye(4, dtype=np.float64) + obj_shift[:3, 3] = [ + -float(clutter_center_xy[0]), + -float(clutter_center_xy[1]), + _table_max_z - obj_bottom_z + _surface_z_margin, + ] + scene.apply_transform(obj_shift) + shifted_clutter.append((oid, scene)) + + # Export the fitted table and placed objects. + final_table_path = output_dir / "table_fit_to_clutter.glb" + _copy_scene_with_transform(table_scene, z_to_y).export(final_table_path) + + for oid, scene in shifted_clutter: + object_path = output_dir / f"{oid}_on_table.glb" + _copy_scene_with_transform(scene, z_to_y).export(object_path) + placed_objects.append({"id": oid, "path": str(object_path)}) + + # Write the fit manifest. + final_clutter_bounds = _table_fit_scene_union_bounds( + [scene for _, scene in shifted_clutter], + trimesh=trimesh, + ) + final_clutter_aabb_cm = _table_fit_bounds_xy_manifest( + final_clutter_bounds, + unit_scale=100.0, + ) + final_support_centered = { + **final_support, + "center": (support_center + table_shift[:3, 3]).tolist(), + "center_xy": ( + np.asarray(final_support["center_xy"], dtype=np.float64) + - support_center[:2] + ).tolist(), + "corners_xy": ( + np.asarray(final_support["corners_xy"], dtype=np.float64) + - support_center[:2] + ).tolist(), + } + manifest = { + "status": "ok", + "output_dir": str(output_dir), + "table_simready_path": str(table_simready_path), + "table_output_path": str(final_table_path), + "objects": placed_objects, + "margin_cm": margin_cm, + "support_occupancy_ratio": occupancy, + "gravity_settle_table": gravity_settle_table, + "table_bottom_z_after_shift": 0.0, + "support_z_after_shift": support_z_after, + "initial_support_quad": initial_support, + "final_support_quad_centered": final_support_centered, + "clutter_2d_aabb_cm": final_clutter_aabb_cm, + "required_support_size_cm": required_size_cm.tolist(), + "table_xy_scale": { + "uniform_scale": uniform_scale, + "scale_x_raw": scale_x, + "scale_y_raw": scale_y, + "support_size_before_scale_cm": support_size_cm.tolist(), + }, + "fit_check": { + "fits_width": float(final_clutter_aabb_cm["size_xy"][0]) + <= float(np.asarray(final_support_centered["size_xy"])[0] * 100.0), + "fits_depth": float(final_clutter_aabb_cm["size_xy"][1]) + <= float(np.asarray(final_support_centered["size_xy"])[1] * 100.0), + }, + } + manifest_path = output_dir / "table_fit_to_clutter_manifest.json" + _write_table_fit_json(manifest_path, manifest) + return { + "status": "ok", + "manifest_path": relative_path(manifest_path, output_root), + } diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/__init__.py new file mode 100644 index 000000000..ce2215329 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/__init__.py @@ -0,0 +1,33 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.text_layout_manager.layout import ( + _layout_text_objects_grid, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.text_layout_manager.optimization import ( + _optimize_text_layout_slp, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.text_layout_manager.settle import ( + settle_text_objects_to_ground, +) + +__all__ = [ + "_layout_text_objects_grid", + "_optimize_text_layout_slp", + "settle_text_objects_to_ground", +] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/layout.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/layout.py new file mode 100644 index 000000000..7b94a852e --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/layout.py @@ -0,0 +1,383 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + + +from __future__ import annotations + +from typing import Any + +import numpy as np + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.optimization_manager import ( + _center_xy_aabb_layout, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.text_layout_manager.optimization import ( + _optimize_text_layout_slp, +) +__all__ = [ + "_layout_text_objects_grid", +] + +def _transitive_closure( + nodes: list[str], + edges: list[tuple[str, str]], +) -> list[tuple[str, str]]: + """Floyd–Warshall transitive closure over a small set of nodes.""" + if not nodes or not edges: + return list(edges) + idx = {n: i for i, n in enumerate(nodes)} + n = len(nodes) + adj = [[False] * n for _ in range(n)] + for src, dst in edges: + if src in idx and dst in idx: + adj[idx[src]][idx[dst]] = True + for k in range(n): + for i in range(n): + if adj[i][k]: + row_k = adj[k] + row_i = adj[i] + for j in range(n): + if row_k[j]: + row_i[j] = True + closed: list[tuple[str, str]] = [] + for i in range(n): + for j in range(n): + if adj[i][j]: + closed.append((nodes[i], nodes[j])) + return closed + + + +def _longest_path_ranks( + nodes: list[str], + edges: list[tuple[str, str]], +) -> dict[str, int]: + """Assign integer ranks satisfying ``(A,B)`` → rank[A] < rank[B]. + + Uses topological sort + longest-path DP. Returns a rank dict for every + node in *nodes* (default 0 for isolated nodes). + """ + ranks: dict[str, int] = {n: 0 for n in nodes} + if not edges: + return ranks + # Build adjacency and in-degree + adj: dict[str, list[str]] = {n: [] for n in nodes} + in_deg: dict[str, int] = {n: 0 for n in nodes} + present = set(nodes) + for src, dst in edges: + if src not in present or dst not in present: + continue + adj[src].append(dst) + in_deg[dst] += 1 + # Kahn topological sort + queue = [n for n in nodes if in_deg[n] == 0] + order: list[str] = [] + while queue: + u = queue.pop(0) + order.append(u) + for v in adj[u]: + in_deg[v] -= 1 + if in_deg[v] == 0: + queue.append(v) + # Longest path + for u in order: + for v in adj[u]: + if ranks[v] < ranks[u] + 1: + ranks[v] = ranks[u] + 1 + # Remaining nodes (cycles / isolated) keep rank 0 + return ranks + + + +def _layout_text_objects_grid( + *, + object_ids: list[str], + xy_sizes: dict[str, np.ndarray], + spatial_relations: list[dict[str, Any]], + table_constraints: list[dict[str, Any]] | None = None, + grid_spacing: float = 0.02, + padding_ratio: float = 0.08, +) -> dict[str, Any]: + """Lay out text-scene objects — transitive closure + longest-path ranks. + + 1. Transitive closure of left_of / front_of. + 2. Pick centre: explicit 9‑grid ʻcenterʼ, else highest-degree node. + 3. Longest-path rank assignment (left_of→X, front_of→Y). + 4. Shift 9‑grid anchors to their grid positions. + 5. Free objects auto‑wrap below. + 6. Convert ranks→XY using per‑column/row max sizes + gaps. + 7. SA point optimisation + mesh AABB collision cleanup. + """ + if not object_ids: + return { + "centers": {}, + "initial_centers": {}, + "metadata": { + "method": "transitive_closure_longest_path_with_9grid", + "iterations": 0, + }, + } + + # Parse spatial relations. + left_of_edges: list[tuple[str, str]] = [] + front_of_edges: list[tuple[str, str]] = [] + seen: set[tuple[str, str, str]] = set() + for rel in spatial_relations: + subject = str(rel.get("subject") or "") + obj = str(rel.get("object") or "") + relation = str(rel.get("relation") or "") + if not subject or not obj or subject == obj: + continue + key = (subject, relation, obj) + if key in seen: + continue + seen.add(key) + if relation == "left_of": + left_of_edges.append((subject, obj)) + elif relation == "front_of": + front_of_edges.append((subject, obj)) + + # Compute transitive closures. + left_of_closed = _transitive_closure(object_ids, left_of_edges) + front_of_closed = _transitive_closure(object_ids, front_of_edges) + + # Parse nine-grid constraints. + # −Y = front, so front row = 0, back row = 2 + _GRID_TO_RC: dict[str, tuple[int, int]] = { + "left_front": (0, 0), "center_front": (1, 0), "right_front": (2, 0), + "left_center": (0, 1), "center": (1, 1), "right_center": (2, 1), + "left_back": (0, 2), "center_back": (1, 2), "right_back": (2, 2), + "front": (1, 0), "back": (1, 2), + "left": (0, 1), "right": (2, 1), + } + grid_targets: dict[str, tuple[int, int]] = {} + for tc in (table_constraints or []): + asset = str(tc.get("asset") or "") + grid_name = str(tc.get("grid") or "").strip() + if asset in object_ids and grid_name in _GRID_TO_RC: + grid_targets[asset] = _GRID_TO_RC[grid_name] + + # Select a center object when none is explicit. + auto_center_oid: str | None = None + has_explicit_center = any( + tc.get("grid") == "center" for tc in (table_constraints or []) + ) + if not has_explicit_center: + # Degree = appearances in left_of + front_of (subject or object) + degree: dict[str, int] = {oid: 0 for oid in object_ids} + for src, dst in left_of_closed + front_of_closed: + if src in degree: + degree[src] += 1 + if dst in degree: + degree[dst] += 1 + max_deg = max(degree.values()) if degree else 0 + if max_deg > 0: + candidates = [oid for oid, d in degree.items() if d == max_deg] + # Tie-breaker: largest AABB area + centre_oid = max( + candidates, + key=lambda oid: float(xy_sizes[oid][0]) * float(xy_sizes[oid][1]), + ) + grid_targets[centre_oid] = (1, 1) # 9‑grid centre + auto_center_oid = centre_oid + + # Derive ranks from the transitive closures. + x_rank = _longest_path_ranks(object_ids, left_of_closed) + # −Y = front: A front_of B → A.y < B.y → row[A] < row[B]. + # _longest_path_ranks gives rank[src] < rank[dst]; edges are + # already (A,B) for "A front_of B", so NO reversal needed. + y_rank = _longest_path_ranks(object_ids, front_of_closed) + + # Apply nine-grid shifts. + # Pin 9‑grid objects to their target ranks; shift all connected + # objects (both upstream and downstream) to preserve topology. + if grid_targets: + # Build undirected connected-components via relation edges + all_edges = left_of_closed + front_of_closed + neighbours: dict[str, set[str]] = {oid: set() for oid in object_ids} + for src, dst in all_edges: + if src in neighbours and dst in neighbours: + neighbours[src].add(dst) + neighbours[dst].add(src) + for oid in grid_targets: + neighbours.setdefault(oid, set()) + + # For each 9‑grid object, BFS the component and shift uniformly + shifted: set[str] = set() + for oid, (target_col, target_row) in grid_targets.items(): + if oid in shifted: + continue + dx = target_col - x_rank.get(oid, 0) + dy = target_row - y_rank.get(oid, 0) + + # BFS to collect the full connected component + component: set[str] = {oid} + queue = [oid] + while queue: + u = queue.pop(0) + for v in neighbours.get(u, set()): + if v not in component: + component.add(v) + queue.append(v) + + for oid2 in component: + if oid2 not in grid_targets: # only shift non‑anchored objects + x_rank[oid2] = x_rank.get(oid2, 0) + dx + y_rank[oid2] = y_rank.get(oid2, 0) + dy + shifted.update(component) + + # Propagate row and column alignment. + # left_of A B → same row (y_rank[A] = y_rank[B]) + # front_of A B → same col (x_rank[A] = x_rank[B]) + # Priority (higher wins): 9‑grid > higher degree > larger area. + _prio = { + oid: ( + oid in grid_targets, + sum(1 for e in left_of_closed + front_of_closed if oid in e), + float(xy_sizes[oid][0]) * float(xy_sizes[oid][1]), + ) + for oid in object_ids + } + for src, dst in left_of_closed: + if _prio[src] >= _prio[dst]: + y_rank[dst] = y_rank.get(src, 0) + else: + y_rank[src] = y_rank.get(dst, 0) + for src, dst in front_of_closed: + if _prio[src] >= _prio[dst]: + x_rank[dst] = x_rank.get(src, 0) + else: + x_rank[src] = x_rank.get(dst, 0) + + # Normalise to >= 0 + min_x = min(x_rank.values()) if x_rank else 0 + min_y = min(y_rank.values()) if y_rank else 0 + for oid in object_ids: + x_rank[oid] = x_rank.get(oid, 0) - min_x + y_rank[oid] = y_rank.get(oid, 0) - min_y + + # Resolve cell collisions: spread objects sharing the same (col, row) + cell_occupants: dict[tuple[int, int], list[str]] = {} + for oid in object_ids: + cell = (x_rank[oid], y_rank[oid]) + cell_occupants.setdefault(cell, []).append(oid) + for (col, row), occupants in cell_occupants.items(): + if len(occupants) > 1: + for offset, oid in enumerate(occupants[1:], start=1): + x_rank[oid] = col + offset + + # Place unconstrained objects in wrapped rows. + constrained = set() + for src, dst in left_of_closed + front_of_closed: + constrained.update([src, dst]) + constrained.update(grid_targets) + free_objects = [oid for oid in object_ids if oid not in constrained] + + if free_objects: + free_row = max(y_rank.values()) + 1 if y_rank else 0 + # Max row width ≈ existing union width × 1.5 (at least 3 cols) + col_keys = list(x_rank.values()) + existing_cols = max(col_keys) - min(col_keys) + 1 if col_keys else 1 + max_cols_per_row = max(existing_cols, 3) + free_sorted = sorted( + free_objects, + key=lambda oid: float(xy_sizes[oid][0]), + reverse=True, + ) + col = 0 + row_offset = 0 + for oid in free_sorted: + x_rank[oid] = col + y_rank[oid] = free_row + row_offset + col += 1 + if col >= max_cols_per_row: + col = 0 + row_offset += 1 + + # Convert ranks to XY positions. + col_widths: dict[int, float] = {} + row_heights: dict[int, float] = {} + for oid in object_ids: + c = x_rank[oid] + r = y_rank[oid] + col_widths[c] = max(col_widths.get(c, 0.0), float(xy_sizes[oid][0])) + row_heights[r] = max(row_heights.get(r, 0.0), float(xy_sizes[oid][1])) + + x_cumsum: dict[int, float] = {} + cumulative = 0.0 + for c in sorted(col_widths): + x_cumsum[c] = cumulative + cumulative += col_widths[c] + grid_spacing + + y_cumsum: dict[int, float] = {} + cumulative = 0.0 + for r in sorted(row_heights): + y_cumsum[r] = cumulative + cumulative += row_heights[r] + grid_spacing + + centers: dict[str, np.ndarray] = {} + for oid in object_ids: + c = x_rank[oid] + r = y_rank[oid] + cx = x_cumsum[c] + 0.5 * float(xy_sizes[oid][0]) + cy = y_cumsum[r] + 0.5 * float(xy_sizes[oid][1]) + centers[oid] = np.array([cx, cy], dtype=np.float64) + + centers = _center_xy_aabb_layout(centers=centers, xy_sizes=xy_sizes) + + initial_centers = {oid: c.copy() for oid, c in centers.items()} + + # Snap initial grid positions as 9‑grid spring targets + grid_spring_targets: dict[str, np.ndarray] = { + oid: initial_centers[oid].copy() + for oid in grid_targets + if oid in initial_centers + } + + # Optimize positions and remove mesh AABB collisions. + optimized = _optimize_text_layout_slp( + object_ids=object_ids, + xy_sizes=xy_sizes, + initial_centers=initial_centers, + left_of_edges=left_of_closed, + front_of_edges=front_of_closed, + grid_spring_targets=grid_spring_targets, + padding_ratio=padding_ratio, + ) + centers = optimized["centers"] + optimization_metadata = optimized["metadata"] + + # Collect layout metadata. + metadata = { + "method": "transitive_closure_longest_path_with_9grid_and_sa", + "grid_spacing": grid_spacing, + "auto_center_oid": auto_center_oid, + "has_explicit_center": has_explicit_center, + "table_constraint_count": len(grid_targets), + "left_of_count": len(left_of_edges), + "left_of_closed_count": len(left_of_closed), + "front_of_count": len(front_of_edges), + "front_of_closed_count": len(front_of_closed), + "free_object_count": len(free_objects), + "x_ranks": {oid: x_rank.get(oid, 0) for oid in object_ids}, + "y_ranks": {oid: y_rank.get(oid, 0) for oid in object_ids}, + "optimization": optimization_metadata, + } + return { + "centers": centers, + "initial_centers": initial_centers, + "metadata": metadata, + } diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/optimization.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/optimization.py new file mode 100644 index 000000000..b8915fc4c --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/optimization.py @@ -0,0 +1,404 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +import numpy as np +from scipy.optimize import minimize + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.optimization_manager import ( + _center_xy_aabb_layout, + _footprint_layout_diagnostics, + _xy_aabb_overlap, + _xy_union_bounds, +) + +__all__ = ["_optimize_text_layout_slp"] + +# SLSQP solve options — matching the original example_optimization SA pipeline. +_SLSQP_OPTIONS: dict[str, Any] = {"maxiter": 500, "ftol": 1e-6, "disp": False} + +# Objective weights (relations are hard constraints, not in the objective). +_WEIGHTS: dict[str, float] = { + "seed": 1.0, + "overlap": 200.0, + "grid": 100.0, +} + + +def _optimize_text_layout_slp( + *, + object_ids: list[str], + xy_sizes: dict[str, np.ndarray], + initial_centers: dict[str, np.ndarray], + left_of_edges: list[tuple[str, str]], + front_of_edges: list[tuple[str, str]], + grid_spring_targets: dict[str, np.ndarray], + padding_ratio: float, +) -> dict[str, Any]: + """Optimize 2D centres with scipy SLSQP, then remove mesh AABB overlap. + + Mirroring the original example_optimization/SA pipeline: + - left_of / front_of → linear inequality constraints + - bounding box → variable bounds (2× initial union) + - seed / overlap / grid → soft penalties in the objective + - post‑solve collision cleanup on actual footprint AABBs + """ + if not object_ids: + return { + "centers": {}, + "metadata": { + "method": "text_slsqp_then_mesh_aabb_collision_removal", + "slsqp_iterations": 0, + "collision_iterations": 0, + }, + } + + max_extent = max( + float(max(xy_sizes[oid][0], xy_sizes[oid][1])) for oid in object_ids + ) + padding = max(max_extent * padding_ratio, 1e-3) + + initial_centers = { + oid: np.asarray(initial_centers[oid], dtype=np.float64).copy() + for oid in object_ids + } + initial_union_bounds = _xy_union_bounds( + centers=initial_centers, + xy_sizes=xy_sizes, + ) + + index_by_id = {oid: i for i, oid in enumerate(object_ids)} + x0 = _pack_centers(object_ids, initial_centers) + + # Build linear inequality constraints for left_of and front_of. + constraints: list[dict[str, Any]] = [] + _build_relation_constraints( + constraints=constraints, + object_ids=object_ids, + index_by_id=index_by_id, + xy_sizes=xy_sizes, + left_of_edges=left_of_edges, + front_of_edges=front_of_edges, + padding=padding, + ) + + # Bound variables to twice the initial union size. + init_size = initial_union_bounds[1] - initial_union_bounds[0] + margin = init_size * 0.5 # 50 % each side → 2× total + bounds = [] + for oid in object_ids: + bounds.append( + ( + float(initial_union_bounds[0, 0] - margin[0]), + float(initial_union_bounds[1, 0] + margin[0]), + ) + ) # x + bounds.append( + ( + float(initial_union_bounds[0, 1] - margin[1]), + float(initial_union_bounds[1, 1] + margin[1]), + ) + ) # y + + # Define the optimization objective. + def _objective(xvec: np.ndarray) -> float: + centers = _unpack_centers(object_ids, xvec) + loss = 0.0 + + # seed: stay close to initial positions + for oid in object_ids: + delta = centers[oid] - initial_centers[oid] + loss += _WEIGHTS["seed"] * float(np.dot(delta, delta)) + + # overlap: AABB overlap area penalty + for i, oid in enumerate(object_ids): + for other_id in object_ids[i + 1 :]: + ov = _xy_aabb_overlap( + center_a=centers[oid], + size_a=xy_sizes[oid], + center_b=centers[other_id], + size_b=xy_sizes[other_id], + padding=padding, + ) + if ov is not None: + loss += _WEIGHTS["overlap"] * float(ov[0] * ov[1]) + + # grid: spring toward 9‑grid targets + for oid, target in grid_spring_targets.items(): + if oid not in centers: + continue + delta = centers[oid] - target + loss += _WEIGHTS["grid"] * float(np.dot(delta, delta)) + + return float(loss) + + # Solve the constrained optimization problem. + slsqp_result: dict[str, Any] = {"success": False, "nit": 0, "message": ""} + try: + result = minimize( + _objective, + x0, + method="SLSQP", + bounds=bounds, + constraints=constraints, + options=_SLSQP_OPTIONS, + ) + slsqp_result = { + "success": bool(result.success), + "nit": int(getattr(result, "nit", 0)), + "message": str(result.message), + "fun": float(result.fun) if result.fun is not None else None, + } + if result.success: + x_opt = result.x + else: + # SLSQP failed — fall back to seed positions + x_opt = x0.copy() + except Exception: + x_opt = x0.copy() + slsqp_result["message"] = "SLSQP raised an exception; using seed positions." + + centers = _unpack_centers(object_ids, x_opt) + centers = _center_xy_aabb_layout(centers=centers, xy_sizes=xy_sizes) + + # Remove residual collisions. + centers, collision_metadata = _remove_mesh_aabb_collisions( + object_ids=object_ids, + xy_sizes=xy_sizes, + centers=centers, + initial_centers=initial_centers, + left_of_edges=left_of_edges, + front_of_edges=front_of_edges, + padding=padding, + ) + centers = _center_xy_aabb_layout(centers=centers, xy_sizes=xy_sizes) + + # Collect optimization metadata. + diagnostics = _footprint_layout_diagnostics( + object_ids=object_ids, + centers=centers, + initial_centers=initial_centers, + xy_sizes=xy_sizes, + padding=padding, + initial_union_bounds=initial_union_bounds, + ) + metadata: dict[str, Any] = { + "method": "text_slsqp_then_mesh_aabb_collision_removal", + "relation_usage": "left_of_front_of_hard_constraints", + "padding": float(padding), + "padding_ratio": float(padding_ratio), + "weights": dict(_WEIGHTS), + "slsqp": slsqp_result, + "bounds_expansion": 2.0, + "initial_union_size": init_size.tolist(), + **collision_metadata, + "final_centers": { + oid: centers[oid].tolist() for oid in object_ids + }, + **diagnostics, + } + return {"centers": centers, "metadata": metadata} + + +# Build relation constraints. + + +def _build_relation_constraints( + *, + constraints: list[dict[str, Any]], + object_ids: list[str], + index_by_id: dict[str, int], + xy_sizes: dict[str, np.ndarray], + left_of_edges: list[tuple[str, str]], + front_of_edges: list[tuple[str, str]], + padding: float, +) -> None: + """Append SLSQP inequality constraints for left_of / front_of edges.""" + + for subject, obj in left_of_edges: + if subject not in index_by_id or obj not in index_by_id: + continue + i_a = index_by_id[subject] + i_b = index_by_id[obj] + # A.x + gap ≤ B.x → B.x - A.x - gap ≥ 0 + gap = ( + 0.5 * float(xy_sizes[subject][0]) + + 0.5 * float(xy_sizes[obj][0]) + + padding + ) + constraints.append( + { + "type": "ineq", + "fun": lambda x, ia=i_a, ib=i_b, g=gap: float( + x[2 * ib] - x[2 * ia] - g + ), + } + ) + + for subject, obj in front_of_edges: + if subject not in index_by_id or obj not in index_by_id: + continue + i_a = index_by_id[subject] + i_b = index_by_id[obj] + # A.y + gap ≤ B.y → B.y - A.y - gap ≥ 0 + gap = ( + 0.5 * float(xy_sizes[subject][1]) + + 0.5 * float(xy_sizes[obj][1]) + + padding + ) + constraints.append( + { + "type": "ineq", + "fun": lambda x, ia=i_a, ib=i_b, g=gap: float( + x[2 * ib + 1] - x[2 * ia + 1] - g + ), + } + ) + + +# Remove AABB collisions. + + +def _remove_mesh_aabb_collisions( + *, + object_ids: list[str], + xy_sizes: dict[str, np.ndarray], + centers: dict[str, np.ndarray], + initial_centers: dict[str, np.ndarray], + left_of_edges: list[tuple[str, str]], + front_of_edges: list[tuple[str, str]], + padding: float, +) -> tuple[dict[str, np.ndarray], dict[str, Any]]: + relation_pairs = set(left_of_edges + front_of_edges) + relation_pairs.update((b, a) for a, b in left_of_edges + front_of_edges) + current = { + oid: np.asarray(center, dtype=np.float64).copy() + for oid, center in centers.items() + } + max_rounds = 80 + total_pushes = 0 + last_overlap_count = 0 + + for iteration in range(max_rounds): + overlaps = _mesh_aabb_collision_pairs( + object_ids=object_ids, + xy_sizes=xy_sizes, + centers=current, + padding=padding, + ) + last_overlap_count = len(overlaps) + if not overlaps: + return current, { + "collision_iterations": iteration, + "collision_pushes": total_pushes, + "collision_remaining": 0, + "collision_removal": "iterative_mesh_aabb_push", + } + for item in overlaps: + object_a = item["object"] + object_b = item["other"] + axis = int(item["axis"]) + sign = -1.0 if current[object_a][axis] <= current[object_b][axis] else 1.0 + amount = 0.5 * (float(item["overlap"]) + 1.0e-6) + if (object_a, object_b) in relation_pairs: + current[object_a][axis] += sign * amount + current[object_b][axis] -= sign * amount + else: + drift_a = np.linalg.norm( + current[object_a] - initial_centers[object_a] + ) + drift_b = np.linalg.norm( + current[object_b] - initial_centers[object_b] + ) + if drift_a <= drift_b: + current[object_a][axis] += sign * amount * 1.25 + current[object_b][axis] -= sign * amount * 0.75 + else: + current[object_a][axis] += sign * amount * 0.75 + current[object_b][axis] -= sign * amount * 1.25 + total_pushes += 1 + current = _center_xy_aabb_layout(centers=current, xy_sizes=xy_sizes) + + return current, { + "collision_iterations": max_rounds, + "collision_pushes": total_pushes, + "collision_remaining": last_overlap_count, + "collision_removal": "iterative_mesh_aabb_push", + } + + +def _mesh_aabb_collision_pairs( + *, + object_ids: list[str], + xy_sizes: dict[str, np.ndarray], + centers: dict[str, np.ndarray], + padding: float, +) -> list[dict[str, Any]]: + pairs: list[dict[str, Any]] = [] + for i, oid in enumerate(object_ids): + for other_id in object_ids[i + 1 :]: + ov = _xy_aabb_overlap( + center_a=centers[oid], + size_a=xy_sizes[oid], + center_b=centers[other_id], + size_b=xy_sizes[other_id], + padding=padding, + ) + if ov is None: + continue + axis = 0 if ov[0] <= ov[1] else 1 + pairs.append( + { + "object": oid, + "other": other_id, + "axis": axis, + "overlap": float(ov[axis]), + "overlap_x": float(ov[0]), + "overlap_y": float(ov[1]), + } + ) + pairs.sort(key=lambda item: item["overlap"], reverse=True) + return pairs + + +# Pack and unpack center coordinates. + + +def _pack_centers( + object_ids: list[str], + centers: dict[str, np.ndarray], +) -> np.ndarray: + values: list[float] = [] + for oid in object_ids: + c = np.asarray(centers[oid], dtype=np.float64) + values.extend([float(c[0]), float(c[1])]) + return np.asarray(values, dtype=np.float64) + + +def _unpack_centers( + object_ids: list[str], + xvec: np.ndarray, +) -> dict[str, np.ndarray]: + return { + oid: np.asarray( + [xvec[2 * i], xvec[2 * i + 1]], + dtype=np.float64, + ) + for i, oid in enumerate(object_ids) + } diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/settle.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/settle.py new file mode 100644 index 000000000..da3cdde6e --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/settle.py @@ -0,0 +1,429 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + + +from __future__ import annotations + +import tempfile +import traceback +from pathlib import Path +from typing import Any + +import numpy as np + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager import ( + SimulationManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager.schemas import ( + GravityDropRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.optimization_manager import ( + _object_scenes_xy_aabb_manifest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.scene_geometry import ( + _aabb_bottom_to_xy_plane_transform, + _copy_scene_with_transform, + _matrix_from_json, + _scale_transform, + _scene_to_mesh, + _xy_aabb_center, + _xy_aabb_size, + _z_up_to_glb_y_up_transform, +) +from embodichain.gen_sim.prompt2scene.utils.io import ( + relative_path, + write_json, +) +from embodichain.gen_sim.prompt2scene.utils.log import log_warning +from embodichain.gen_sim.prompt2scene.agent_tools.managers.matplotlib_manager import ( + MatplotlibManager, + RenderFootprintLayoutRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.text_layout_manager.layout import ( + _layout_text_objects_grid, +) + +__all__ = ["settle_text_objects_to_ground"] + + +def settle_text_objects_to_ground( + *, + objects: list[dict[str, Any]], + spatial_relations: list[dict[str, Any]] | None = None, + table_constraints: list[dict[str, Any]] | None = None, + output_dir: Path, + output_root: Path, + sim_device: str = "cpu", +) -> dict[str, Any]: + """Scale simready objects to real-world size, gravity-settle, layout on table. + + For each text-input object: + 1. Load simready GLB (GLB Y-up) → convert to internal Z-up + 2. Apply scene-level metric scale_factor → real-world size + 3. Gravity simulation to settle on ground plane + 4. Move AABB bottom centre to XY origin at Z=0 + 5. Build grid/rank initialization from left_of/front_of and table constraints + 6. Run SA-based 2D point optimization and mesh AABB collision cleanup + 7. Apply layout positions + + Returns laid-out scenes and per-object metadata. + """ + try: + import trimesh + except ImportError as exc: + raise RuntimeError("Text object gravity settling requires trimesh.") from exc + + output_dir = output_dir.expanduser().resolve() + output_dir.mkdir(parents=True, exist_ok=True) + sim = SimulationManager(headless=True, sim_device=sim_device) + z_to_y = _z_up_to_glb_y_up_transform() + y_to_z = np.linalg.inv(z_to_y) + + settled_objects: list[dict[str, Any]] = [] + object_scenes: list[tuple[str, Any]] = [] + + with tempfile.TemporaryDirectory(prefix="p2s_text_settle_") as tmp_dir: + tmp_path = Path(tmp_dir) + for obj in objects: + obj_id = str(obj.get("id", "")) + obj_name = str(obj.get("name", "")) + + # Validate the metric scale. + metric_scale = obj.get("metric_scale") + if not isinstance(metric_scale, dict): + settled_objects.append( + { + "id": obj_id, + "name": obj_name, + "status": "skipped", + "reason": "missing_metric_scale", + } + ) + continue + scale_factor = float(metric_scale.get("scale_factor", 1.0)) + if not np.isfinite(scale_factor) or scale_factor <= 0.0: + settled_objects.append( + { + "id": obj_id, + "name": obj_name, + "status": "skipped", + "reason": "invalid_scale_factor", + } + ) + continue + + # Load the simulation-ready GLB. + simready_path = _resolve_generated_path( + obj.get("simready_geometry_path") or obj.get("mesh_path"), + output_root, + ) + if not simready_path.is_file(): + settled_objects.append( + { + "id": obj_id, + "name": obj_name, + "status": "skipped", + "reason": "missing_simready_glb", + } + ) + continue + + try: + # Load simready (GLB Y-up) → convert to internal Z-up + scene_yup = trimesh.load(simready_path, force="scene") + scene = _copy_scene_with_transform(scene_yup, y_to_z) + + # Apply real-world scale + scale_transform = _scale_transform(scale_factor) + scene.apply_transform(scale_transform) + + # Settle the object under gravity. + mesh = _scene_to_mesh(scene, trimesh=trimesh) + mesh_bounds = np.asarray(mesh.bounds, dtype=np.float64) + mesh_z_height = max(float(mesh_bounds[1][2] - mesh_bounds[0][2]), 0.0) + bottom_to_xy = _aabb_bottom_to_xy_plane_transform(mesh_bounds) + normalized_scene = _copy_scene_with_transform(scene, bottom_to_xy) + + # Export to Y-up GLB for gravity + pre_gravity_scene = _copy_scene_with_transform(normalized_scene, z_to_y) + pre_gravity_path = tmp_path / f"{obj_id}_pre_gravity.glb" + pre_gravity_scene.export(pre_gravity_path) + gravity_initial_height = mesh_z_height * 0.1 + + gravity_status = "ok" + gravity_transform = np.eye(4, dtype=np.float64) + gravity_reason = "" + try: + gravity_result = sim.run_gravity_simulation( + GravityDropRequest( + glb_path=pre_gravity_path, + max_convex_hull_num=32, + initial_height=gravity_initial_height, + ) + ) + gravity_transform = _matrix_from_json( + gravity_result.final_pose, + name=f"{obj_id}.gravity_final_pose", + ) + except Exception: + gravity_status = "failed" + gravity_reason = traceback.format_exc() + + # Apply gravity result (in internal Z-up space) + settled_scene = _copy_scene_with_transform( + normalized_scene, + gravity_transform, + ) + + # Center the bottom of the AABB at the XY origin. + settled_mesh = _scene_to_mesh(settled_scene, trimesh=trimesh) + settled_bounds = np.asarray(settled_mesh.bounds, dtype=np.float64) + settled_xy_center = _xy_aabb_center(settled_bounds) + settled_xy_size = _xy_aabb_size(settled_bounds) + settled_bottom_z = float(settled_bounds[0, 2]) + + centre_transform = np.eye(4, dtype=np.float64) + centre_transform[:3, 3] = [ + -float(settled_xy_center[0]), + -float(settled_xy_center[1]), + -settled_bottom_z, + ] + centred_scene = _copy_scene_with_transform( + settled_scene, + centre_transform, + ) + + # Verify final bounds + centred_mesh = _scene_to_mesh(centred_scene, trimesh=trimesh) + centred_bounds = np.asarray(centred_mesh.bounds, dtype=np.float64) + centred_xy_size = _xy_aabb_size(centred_bounds) + + # Export settled GLB (Z-up → Y-up for GLB output) + settled_glb_path = output_dir / f"{obj_id}_settled.glb" + _copy_scene_with_transform(centred_scene, z_to_y).export( + settled_glb_path + ) + + item = { + "id": obj_id, + "name": obj_name, + "status": "ok", + "gravity_status": gravity_status, + "gravity_reason": gravity_reason, + "scale_factor": scale_factor, + "settled_glb_path": relative_path( + str(settled_glb_path), + output_root, + ), + "settled_xy_size_m": centred_xy_size.tolist(), + "settled_xy_size_cm": (centred_xy_size * 100.0).tolist(), + "settled_bounds_m": centred_bounds.tolist(), + "mesh_z_height_m": mesh_z_height, + "bottom_to_xy_transform": bottom_to_xy.tolist(), + "gravity_transform": gravity_transform.tolist(), + "centre_transform": centre_transform.tolist(), + "composed_settle_transform": ( + centre_transform + @ gravity_transform + @ bottom_to_xy + @ scale_transform + @ y_to_z + ).tolist(), + } + settled_objects.append(item) + object_scenes.append((obj_id, centred_scene)) + + except Exception: + settled_objects.append( + { + "id": obj_id, + "name": obj_name, + "status": "failed", + "reason": traceback.format_exc(), + } + ) + + # Optimize the spatial layout. + layout_result = None + if object_scenes: + xy_sizes = { + oid: np.asarray( + _xy_aabb_size(_scene_to_mesh(scene, trimesh=trimesh).bounds), + dtype=np.float64, + ) + for oid, scene in object_scenes + } + relations = list(spatial_relations or []) + layout_result = _layout_text_objects_grid( + object_ids=[oid for oid, _ in object_scenes], + xy_sizes=xy_sizes, + spatial_relations=relations, + table_constraints=list(table_constraints or []), + ) + target_centers = layout_result["centers"] + initial_centers = layout_result.get("initial_centers", {}) + + # Render footprint layout diagnostics. + debug_dir = output_dir / "debug" + debug_dir.mkdir(parents=True, exist_ok=True) + debug_object_ids = [oid for oid, _ in object_scenes] + debug_before_centers = { + oid: np.zeros(2, dtype=np.float64) for oid in debug_object_ids + } + debug_renders = ( + ( + "footprint_layout_xy_before.png", + "Before Layout (all at origin)", + debug_before_centers, + ), + ( + "footprint_layout_xy_grid_init.png", + "After Grid Initialisation", + initial_centers, + ), + ( + "footprint_layout_xy_after.png", + "After SA Optimisation", + target_centers, + ), + ) + for filename, title, debug_centers in debug_renders: + try: + MatplotlibManager(figsize=(8, 8), dpi=180).render_footprint_layout( + RenderFootprintLayoutRequest( + object_ids=debug_object_ids, + centers=debug_centers, + xy_sizes=xy_sizes, + output_path=debug_dir / filename, + title=title, + ) + ) + except Exception as exc: + log_warning( + f"text clutter debug render failed file={filename} error={exc}" + ) + + # Apply layout positions to centred scenes + laid_out_scenes: list[tuple[str, Any]] = [] + for oid, scene in object_scenes: + target_xy = target_centers[oid] + settled_mesh = _scene_to_mesh(scene, trimesh=trimesh) + settled_bounds = np.asarray(settled_mesh.bounds, dtype=np.float64) + current_xy = _xy_aabb_center(settled_bounds) + placement = np.eye(4, dtype=np.float64) + placement[:3, 3] = [ + float(target_xy[0] - current_xy[0]), + float(target_xy[1] - current_xy[1]), + 0.0, + ] + laid_out_scene = _copy_scene_with_transform(scene, placement) + laid_out_scenes.append((oid, laid_out_scene)) + + # Export laid-out GLB (replaces the origin-centred one) + laid_out_glb_path = output_dir / f"{oid}_laid_out.glb" + _copy_scene_with_transform(laid_out_scene, z_to_y).export(laid_out_glb_path) + + # Update per-object metadata with layout position + for item in settled_objects: + if item.get("id") == oid: + item["layout_target_xy"] = target_xy.tolist() + item["layout_placement_transform"] = placement.tolist() + item["laid_out_glb_path"] = relative_path( + str(laid_out_glb_path), output_root + ) + laid_out_bounds = np.asarray( + _scene_to_mesh(laid_out_scene, trimesh=trimesh).bounds, + dtype=np.float64, + ) + item["laid_out_xy_size_cm"] = ( + _xy_aabb_size(laid_out_bounds) * 100.0 + ).tolist() + break + + object_scenes = laid_out_scenes + + clutter_2d_aabb_cm = _object_scenes_xy_aabb_manifest( + object_scenes=object_scenes, + trimesh=trimesh, + unit_scale=100.0, + unit="cm", + ) + + debug_manifest = { + "status": "ok", + "output_dir": relative_path(str(output_dir), output_root), + "object_count": len(objects), + "settled_count": len(object_scenes), + "clutter_2d_aabb_cm": clutter_2d_aabb_cm, + "debug_image_before_path": ( + relative_path( + str(debug_dir / "footprint_layout_xy_before.png"), + output_root, + ) + if object_scenes + else "" + ), + "debug_image_grid_init_path": ( + relative_path( + str(debug_dir / "footprint_layout_xy_grid_init.png"), + output_root, + ) + if object_scenes + else "" + ), + "debug_image_after_path": ( + relative_path( + str(debug_dir / "footprint_layout_xy_after.png"), + output_root, + ) + if object_scenes + else "" + ), + "layout_optimization": layout_result["metadata"] if layout_result else None, + "objects": settled_objects, + } + debug_manifest_path = output_dir / "debug" / "settle_diagnostics.json" + write_json(debug_manifest_path, debug_manifest) + + # Keep workflow state limited to the contract consumed by table fitting. + workflow_objects = [ + { + key: item[key] + for key in ( + "id", + "name", + "status", + "reason", + "settled_glb_path", + "laid_out_glb_path", + ) + if key in item + } + for item in settled_objects + ] + return { + "status": "ok", + "clutter_2d_aabb_cm": clutter_2d_aabb_cm, + "objects": workflow_objects, + "debug_manifest_path": relative_path(str(debug_manifest_path), output_root), + } + + +def _resolve_generated_path(value: Any, output_root: Path) -> Path: + path = Path(str(value or "")).expanduser() + if path.is_absolute(): + return path.resolve() + return (output_root / path).resolve() diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/servers/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/servers/__init__.py new file mode 100644 index 000000000..e50272eff --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/servers/__init__.py @@ -0,0 +1,16 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +"""External servers, ignored by git, for testing or demo purposes.""" \ No newline at end of file diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/__init__.py new file mode 100644 index 000000000..015c41510 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/__init__.py @@ -0,0 +1,19 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +__all__: list[str] = [] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py new file mode 100644 index 000000000..9f3c638f5 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py @@ -0,0 +1,319 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +import math +import shutil +import time +from pathlib import Path +from typing import Any + +import numpy as np + +from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( + STEP_RESULT_FILENAME, + UNIFIED_SCENE_GEN_STEP, +) + +__all__ = ["export_gym_config"] + +_DEFAULT_OBJECT_ATTRS: dict[str, Any] = { + "mass": 0.01, + "contact_offset": 0.003, + "rest_offset": 0.001, + "restitution": 0.01, + "max_depenetration_velocity": 10.0, + "min_position_iters": 32, + "min_velocity_iters": 8, +} + +_DEFAULT_TABLE_ATTRS: dict[str, Any] = { + "mass": 10.0, + "static_friction": 0.95, + "dynamic_friction": 0.9, + "restitution": 0.01, +} + +_DEFAULT_MAX_CONVEX_HULL_NUM = 8 + + +def _resolve_path(value: str, output_root: Path) -> Path: + path = Path(value).expanduser() + if path.is_absolute(): + return path.resolve() + return (output_root / path).resolve() + + +def _read_json(path: Path) -> dict[str, Any]: + with path.open("r", encoding="utf-8") as f: + data = json.load(f) + if not isinstance(data, dict): + raise ValueError(f"Expected JSON object at {path}") + return data + + +def _matrix_to_euler_xyz_deg(matrix: list[list[float]]) -> list[float]: + """Decompose a 3×3 or 4×4 rotation matrix into XYZ Euler angles (degrees).""" + m = np.asarray(matrix, dtype=np.float64) + r = m[:3, :3] + sy = math.sqrt(float(r[0, 0]) ** 2 + float(r[1, 0]) ** 2) + if sy > 1e-6: + x = math.atan2(float(r[2, 1]), float(r[2, 2])) + y = math.atan2(-float(r[2, 0]), sy) + z = math.atan2(float(r[1, 0]), float(r[0, 0])) + else: + x = math.atan2(-float(r[1, 2]), float(r[1, 1])) + y = math.atan2(-float(r[2, 0]), sy) + z = 0.0 + return [math.degrees(x), math.degrees(y), math.degrees(z)] + + +def _glb_aabb_bottom_center(glb_path: Path) -> list[float]: + """``[x, y, z]`` bottom-centre position in **simulation Z-up** space. + + The GLB is stored in Y-up convention (X=width, Y=up, Z=depth). + EmbodiChain simulation converts to Z-up on load, so we return the + position in Z-up space: ``center_X``, ``-center_Z``, ``min_Y``. + """ + import trimesh + + scene = trimesh.load(glb_path, force="scene") + if isinstance(scene, trimesh.Trimesh): + mesh = scene + else: + dumped = scene.dump(concatenate=True) + mesh = ( + dumped + if isinstance(dumped, trimesh.Trimesh) + else trimesh.util.concatenate( + [m for m in dumped if isinstance(m, trimesh.Trimesh)] + ) + ) + b = np.asarray(mesh.bounds, dtype=np.float64) + return [ + float(0.5 * (b[0, 0] + b[1, 0])), # centre X + float(-0.5 * (b[0, 2] + b[1, 2])), # -centre Z (GLB Z → internal -Y) + float(b[0, 1]), # min Y (GLB up → internal Z) + ] + + +def _glb_max_z(glb_path: Path) -> float: + """Maximum height (Y in GLB, Z in simulation) of a mesh.""" + import trimesh + + scene = trimesh.load(glb_path, force="scene") + if isinstance(scene, trimesh.Trimesh): + mesh = scene + else: + dumped = scene.dump(concatenate=True) + mesh = ( + dumped + if isinstance(dumped, trimesh.Trimesh) + else trimesh.util.concatenate( + [m for m in dumped if isinstance(m, trimesh.Trimesh)] + ) + ) + return float(np.asarray(mesh.bounds, dtype=np.float64)[1, 1]) # max Y + + +def export_gym_config( + output_root: Path, + *, + export_dir: Path | None = None, +) -> Path: + """Export the unified-scene-gen result as a gym_config.json bundle. + + Uses **simready** GLBs — transforms are written explicitly as + ``body_scale``, ``init_pos``, and ``init_rot``. + """ + output_root = output_root.expanduser().resolve() + if export_dir is None: + export_dir = output_root / "gym_export" + else: + export_dir = export_dir.expanduser().resolve() + export_dir.mkdir(parents=True, exist_ok=True) + + # ── step result & table-fit manifest ────────────────────────────── + step_result = _read_json( + output_root / UNIFIED_SCENE_GEN_STEP / STEP_RESULT_FILENAME + ) + table_fit = step_result.get("table_fit_to_clutter") or {} + manifest = _read_json( + _resolve_path(table_fit.get("manifest_path", ""), output_root) + ) + + # ── per-object metadata from simready→aligned manifest ──────────── + aligned_by_id: dict[str, dict[str, Any]] = {} + aligned_manifest_path = ( + output_root / UNIFIED_SCENE_GEN_STEP / "glb_gen" / "simready_to_aligned_manifest.json" + ) + if aligned_manifest_path.is_file(): + aligned_manifest = _read_json(aligned_manifest_path) + for item in aligned_manifest.get("items", []) or []: + if isinstance(item, dict): + aligned_by_id[str(item.get("id", ""))] = item + + # ── table surface Z (from fitted table GLB) ─────────────────────── + fitted_table_path = _resolve_path( + manifest.get("table_output_path", ""), output_root + ) + table_surface_z = ( + _glb_max_z(fitted_table_path) if fitted_table_path.is_file() else 0.0 + ) + + # ── description lookup ──────────────────────────────────────────── + object_meta_by_id: dict[str, dict[str, str]] = {} + for obj in step_result.get("objects", []) or []: + if isinstance(obj, dict): + oid = str(obj.get("id", "")) + if oid: + object_meta_by_id[oid] = { + "description": str(obj.get("description", "")).strip(), + "name": str(obj.get("name", "")).strip(), + } + + table_info = step_result.get("table") or {} + table_desc = str( + table_info.get("complete_table_description") + or table_info.get("description", "") + ).strip() + + mesh_assets_dir = export_dir / "mesh_assets" + mesh_assets_dir.mkdir(parents=True, exist_ok=True) + + # ── table ───────────────────────────────────────────────────────── + table_simready = _resolve_path( + table_info.get("simready_geometry_path") + or table_info.get("mesh_path", ""), + output_root, + ) + if not table_simready.is_file(): + raise FileNotFoundError(f"Table simready GLB not found: {table_simready}") + table_dst = mesh_assets_dir / "table" / "table_0.glb" + table_dst.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(table_simready, table_dst) + + uniform_scale = 1.0 + ts = manifest.get("table_xy_scale") + if isinstance(ts, dict): + uniform_scale = float(ts.get("uniform_scale", 1.0)) + + # ── objects ─────────────────────────────────────────────────────── + table_fit_objects = { + str(e["id"]): _resolve_path(e["path"], output_root) + for e in (manifest.get("objects") or []) + if isinstance(e, dict) + } + objects_info = step_result.get("objects") or [] + rigid_objects: list[dict[str, Any]] = [] + + def _obj_desc(obj: dict[str, Any]) -> str: + meta = object_meta_by_id.get(str(obj.get("id", ""))) + return (meta["description"] or meta["name"]) if meta else "" + + for obj in objects_info: + if not isinstance(obj, dict): + continue + object_id = str(obj.get("id", "")) + if not object_id: + continue + + # ── GLB: simready (normalised, no baked transforms) ────────── + source = obj.get("simready_geometry_path") or obj.get("mesh_path") + object_src = _resolve_path(source, output_root) + if not object_src.is_file(): + continue + + safe_name = object_id.replace("interact_", "").strip("_") or "object" + obj_dir = mesh_assets_dir / safe_name / object_id + obj_dir.mkdir(parents=True, exist_ok=True) + object_dst = obj_dir / f"{object_id}.glb" + shutil.copy2(object_src, object_dst) + + # ── body_scale ──────────────────────────────────────────────── + ms = obj.get("metric_scale") + scale_factor = float(ms.get("scale_factor", 1.0)) if isinstance(ms, dict) else 1.0 + body_scale = [scale_factor, scale_factor, scale_factor] + + # ── init_pos: read from fitted on-table GLB ─────────────────── + fitted_path = table_fit_objects.get(object_id) + if fitted_path and fitted_path.is_file(): + init_pos = _glb_aabb_bottom_center(fitted_path) + else: + init_pos = [0.0, 0.0, table_surface_z] + + # ── init_rot: decompose from simready→aligned rotation ──────── + init_rot: list[float] = [0.0, 0.0, 0.0] + aligned = aligned_by_id.get(object_id) + if aligned: + rot = aligned.get("rotation_matrix") + if rot and isinstance(rot, list): + init_rot = _matrix_to_euler_xyz_deg(rot) + + rigid_objects.append( + { + "uid": object_id, + "description": _obj_desc(obj), + "shape": { + "shape_type": "Mesh", + "fpath": str(object_dst.relative_to(export_dir)), + "compute_uv": False, + }, + "attrs": dict(_DEFAULT_OBJECT_ATTRS), + "body_type": "dynamic", + "init_pos": init_pos, + "init_rot": init_rot, + "body_scale": body_scale, + "max_convex_hull_num": _DEFAULT_MAX_CONVEX_HULL_NUM, + } + ) + + # ── write config ────────────────────────────────────────────────── + config = { + "id": f"Prompt2Scene-{int(time.time() * 1000)}-v0", + "max_episodes": 10, + "max_episode_steps": 300, + "env": {"events": {}, "observations": {}, "dataset": {}}, + "robot": {}, + "sensor": [], + "light": {}, + "background": [ + { + "uid": "table", + "description": table_desc, + "shape": { + "shape_type": "Mesh", + "fpath": str(table_dst.relative_to(export_dir)), + "compute_uv": False, + }, + "attrs": dict(_DEFAULT_TABLE_ATTRS), + "body_scale": [uniform_scale, uniform_scale, 1.0], + "body_type": "kinematic", + "init_pos": [0.0, 0.0, 0.0], + "init_rot": [0.0, 0.0, 0.0], + } + ], + "rigid_object": rigid_objects, + } + + config_path = export_dir / "gym_config.json" + config_path.write_text( + json.dumps(config, indent=4, ensure_ascii=False) + "\n", + encoding="utf-8", + ) + return config_path diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py new file mode 100644 index 000000000..2275c40fa --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py @@ -0,0 +1,636 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + + +from __future__ import annotations + +import shutil +import traceback +from pathlib import Path +from typing import Any + +import numpy as np + +from embodichain.gen_sim.prompt2scene.utils.log import log_info, log_warning +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client import ( + decode_rle_mask, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_generation_manager import ( + GeometryGenerationManager, + RgbaImageToGeometryRequest, + RgbaImagesToGeometriesRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_generation_manager import ( + ImageGenerationManager, + TextToAssetImageRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_segmentation_manager import ( + AssetImageToRgbaRequest, + ImageSegmentationManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager import ( + _export_support_aligned_layout_glbs, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simready_manager import ( + MakeAssetSimreadyRequest, + MakeTableSimreadyRequest, + SimreadyManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.metric_scale_manager import ( + METRIC_SCALE_ENABLED, + EstimateMetricScalesRequest, + MetricScaleManager, + MetricScaleObjectInput, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.scene_geometry import ( + _compose_sam3d_multi_object_transform, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager import ( + _write_multi_object_layout_manifests, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager.prompts import ( + build_image_metric_scale_messages, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager.schemas import ( + IMAGE_METRIC_SCALE_JSON_SCHEMA, +) +from embodichain.gen_sim.prompt2scene.utils.io import ( + relative_path, +) + +__all__ = ["generate_image_scene_assets"] + +UNIFIED_SCENE_STEP = "unified_scene" + + +def generate_image_scene_assets( + object_specs: list[dict[str, Any]], + table_spec: dict[str, Any], + spatial_relations: list[dict[str, Any]], + segments_data: dict[str, Any], + image_gen_dir: Path, + glb_gen_dir: Path, + debug_dir: Path, + output_root: Path, + llm: Any | None = None, +) -> dict[str, Any]: + """Run layout-aware table/support and object generation from image masks.""" + log_info(f"image object layout generation started count={len(object_specs)}") + status = "ok" + failure_reason = "" + original_image_path = str(segments_data.get("image_path", "")) + segment_by_id: dict[str, dict[str, Any]] = { + str(seg["asset_id"]): seg + for seg in segments_data.get("asset_segments", []) + if seg.get("asset_id") + } + table_segment = segments_data.get("table_segment") + if not isinstance(table_segment, dict): + table_segment = None + debug_subdir = debug_dir / "multi_object_masks" + masks_dir = debug_subdir / "masks" + raw_download_dir = glb_gen_dir / "raw_downloads" + simready_dir = glb_gen_dir / "multi_object_layouts_simready" + aligned_dir = glb_gen_dir / "multi_object_layouts_aligned" + masks_dir.mkdir(parents=True, exist_ok=True) + raw_download_dir.mkdir(parents=True, exist_ok=True) + simready_dir.mkdir(parents=True, exist_ok=True) + aligned_dir.mkdir(parents=True, exist_ok=True) + + requested_items: list[dict[str, Any]] = [] + mask_paths: list[Path] = [] + + table_id = str(table_spec.get("id", "table")).strip() or "table" + table_name = str(table_spec.get("name", "table")).strip() or "table" + is_complete_visible_table = bool( + table_spec.get("is_complete_visible_table", False) + ) + skipped_table: dict[str, Any] | None = None + if table_segment is None: + skipped_table = { + "id": table_id, + "name": table_name, + "reason": "missing_table_segment", + } + else: + table_mask_rle = table_segment.get("mask_rle") + if table_mask_rle is None: + skipped_table = { + "id": table_id, + "name": table_name, + "reason": "missing_table_mask_rle", + } + else: + mask_path = masks_dir / f"{len(requested_items):04d}_{table_id}_mask.png" + decode_rle_mask(table_mask_rle).save(mask_path) + mask_paths.append(mask_path) + requested_items.append( + { + "id": table_id, + "name": table_name, + "kind": "table", + "mask_path": str(mask_path), + } + ) + + for obj_spec in object_specs: + obj_id = str(obj_spec.get("id", "")).strip() + obj_name = str(obj_spec.get("name", "")).strip() + if not obj_id: + continue + segment = segment_by_id.get(obj_id) + if segment is None: + continue + mask_rle = segment.get("mask_rle") + if mask_rle is None: + continue + + mask_path = masks_dir / f"{len(requested_items):04d}_{obj_id}_mask.png" + decode_rle_mask(mask_rle).save(mask_path) + mask_paths.append(mask_path) + requested_items.append( + { + "id": obj_id, + "name": obj_name, + "description": str(obj_spec.get("description", "")), + "kind": "object", + "mask_path": str(mask_path), + } + ) + + generated_objects: list[dict[str, Any]] = [] + generated_table: dict[str, Any] | None = None + image_manager = ImageGenerationManager() + segmentation_manager = ImageSegmentationManager() + geometry_manager = GeometryGenerationManager() + simready_manager = SimreadyManager() + try: + if skipped_table is not None: + raise ValueError( + "No valid table/support mask found for image multi-object " + f"layout generation: {skipped_table['reason']}" + ) + if not mask_paths: + raise ValueError( + "No valid masks found for image multi-object layout generation." + ) + + result = geometry_manager.convert_rgba_images_to_geometries( + RgbaImagesToGeometriesRequest( + image_path=Path(original_image_path), + mask_paths=mask_paths, + output_dir=raw_download_dir, + ) + ) + if len(result.objects) != len(requested_items): + raise RuntimeError( + "Multi-object SAM3D result count mismatch: " + f"requested {len(requested_items)}, got {len(result.objects)}" + ) + for requested, generated in zip(requested_items, result.objects): + expected_sam3d_name = Path(requested["mask_path"]).stem + if generated.name != expected_sam3d_name: + raise RuntimeError( + "Multi-object SAM3D result order mismatch: " + f"expected {expected_sam3d_name!r}, got {generated.name!r}" + ) + downloaded_raw_path = Path(generated.geometry_path).expanduser().resolve() + raw_geometry_path = str(downloaded_raw_path) + status_parts: list[str] = [] + transform_matrix: list[list[float]] = [] + try: + transform = _compose_sam3d_multi_object_transform( + rotation_quaternion_wxyz=generated.rotation_quaternion_wxyz, + translation=generated.translation, + scale=generated.scale, + ) + transform_matrix = transform.tolist() + except Exception: + status_parts.append( + f"transform_matrix_failed: {traceback.format_exc()}" + ) + + simready_geometry_path = "" + raw_to_simready_glb_matrix: list[list[float]] = [] + metric_scale: dict[str, Any] | None = None + try: + if requested["kind"] == "table": + if is_complete_visible_table: + table_result = simready_manager.make_table_simready( + MakeTableSimreadyRequest( + input_path=Path(raw_geometry_path), + output_path=simready_dir + / f"{requested['id']}_simready.glb", + ) + ) + simready_geometry_path = str(table_result.output_path) + raw_to_simready_glb_matrix = table_result.transform_matrix + else: + asset_result = simready_manager.make_asset_simready( + MakeAssetSimreadyRequest( + input_path=Path(raw_geometry_path), + output_path=simready_dir + / f"{requested['id']}_simready.glb", + ) + ) + simready_geometry_path = str(asset_result.output_path) + raw_to_simready_glb_matrix = asset_result.transform_matrix + except Exception: + status_parts.append(f"simready_failed: {traceback.format_exc()}") + item_status = "ok" if not status_parts else "; ".join(status_parts) + generated_item = { + "id": requested["id"], + "name": requested["name"], + "kind": requested["kind"], + "description": str(table_spec.get("description", "")) + if requested["kind"] == "table" + else str(requested.get("description", "")), + "complete_table_description": str( + table_spec.get("complete_table_description") + or table_spec.get("description", "") + ).strip() + if requested["kind"] == "table" + else "", + "is_complete_visible_table": is_complete_visible_table + if requested["kind"] == "table" + else False, + "status": item_status, + "mask_path": relative_path(requested["mask_path"], output_root), + "raw_geometry_path": relative_path(raw_geometry_path, output_root), + "simready_geometry_path": relative_path( + simready_geometry_path, output_root + ) + if simready_geometry_path + else "", + "mesh_path": relative_path(simready_geometry_path, output_root) + if simready_geometry_path + else "", + "sam3d_name": generated.name, + "downloaded_raw_geometry_path": relative_path( + str(downloaded_raw_path), output_root + ), + "rotation_quaternion_wxyz": generated.rotation_quaternion_wxyz, + "translation": generated.translation, + "scale": generated.scale, + "transform_matrix": transform_matrix, + "raw_to_simready_glb_matrix": raw_to_simready_glb_matrix, + "metric_scale": metric_scale, + } + if requested["kind"] == "table": + support_reference_path = raw_download_dir / "support_surface_raw.glb" + table_raw_path = raw_download_dir / "table_raw.glb" + shutil.copy2(downloaded_raw_path, support_reference_path) + if is_complete_visible_table: + shutil.copy2(downloaded_raw_path, table_raw_path) + generated_item["raw_geometry_path"] = relative_path( + str(table_raw_path), + output_root, + ) + generated_item["support_reference_geometry_path"] = relative_path( + str(support_reference_path), + output_root, + ) + generated_item["support_reference_transform_matrix"] = transform_matrix + generated_item["support_normal_source"] = "segmented_table" + generated_item["table_asset_source"] = "segmented_table" + if not is_complete_visible_table: + # Replace partial image table with description-generated table. + incomplete_table_id = str( + generated_item.get("id") + or table_spec.get("id") + or "table" + ) + incomplete_table_desc = str( + table_spec.get("complete_table_description") + or table_spec.get("description", "") + ).strip() + incomplete_debug_dir = ( + debug_dir / incomplete_table_id / "description_generated" + ) + incomplete_debug_dir.mkdir(parents=True, exist_ok=True) + incomplete_raw_download_dir = glb_gen_dir / "raw_downloads" + incomplete_raw_download_dir.mkdir(parents=True, exist_ok=True) + incomplete_raw_image = str( + image_manager.generate_asset_image_from_text( + TextToAssetImageRequest( + prompt=incomplete_table_desc, + output_path=incomplete_debug_dir + / f"{incomplete_table_id}_complete.png", + ) + ) + ) + incomplete_rgba = str( + segmentation_manager.convert_asset_image_to_rgba( + AssetImageToRgbaRequest( + image_path=Path(incomplete_raw_image), + prompt=incomplete_table_desc + if incomplete_table_desc.strip() + else "whole table", + output_path=image_gen_dir + / f"{incomplete_table_id}_complete.png", + ) + ) + ) + incomplete_raw_glb = str( + geometry_manager.convert_rgba_image_to_geometry( + RgbaImageToGeometryRequest( + image_path=Path(incomplete_rgba), + output_path=incomplete_debug_dir + / f"{incomplete_table_id}_complete_raw.glb", + ) + ) + ) + incomplete_table_raw_path = ( + incomplete_raw_download_dir / "table_raw.glb" + ) + shutil.copy2(incomplete_raw_glb, incomplete_table_raw_path) + incomplete_simready = simready_manager.make_table_simready( + MakeTableSimreadyRequest( + input_path=incomplete_table_raw_path, + output_path=glb_gen_dir + / "multi_object_layouts_simready" + / f"{incomplete_table_id}_simready.glb", + ) + ) + generated_item.update( + { + "image_path": relative_path( + incomplete_rgba, output_root + ), + "raw_geometry_path": relative_path( + str(incomplete_table_raw_path), output_root + ), + "generated_table_raw_geometry_path": relative_path( + incomplete_raw_glb, output_root + ), + "simready_geometry_path": relative_path( + str(incomplete_simready.output_path), + output_root, + ), + "mesh_path": relative_path( + str(incomplete_simready.output_path), + output_root, + ), + "raw_to_simready_glb_matrix": ( + incomplete_simready.transform_matrix + ), + "transform_matrix": np.eye( + 4, dtype=np.float64 + ).tolist(), + "table_asset_source": "description_generated", + "complete_table_description": incomplete_table_desc, + } + ) + generated_table = generated_item + else: + generated_objects.append(generated_item) + except Exception as exc: + status = "failed" + failure_reason = traceback.format_exc() + log_warning(f"image object geometry generation failed error={exc}") + + if generated_objects: + _estimate_image_scene_metric_scales( + objects=generated_objects, + bbox_name_image_path=segments_data.get("bbox_name_image_path"), + output_dir=glb_gen_dir, + output_root=output_root, + llm=llm, + ) + + alignment_result: dict[str, Any] | None = None + if generated_table is not None and generated_objects: + try: + alignment_result = _export_support_aligned_layout_glbs( + table=generated_table, + objects=generated_objects, + spatial_relations=spatial_relations, + original_image_path=Path(original_image_path) + if original_image_path + else None, + llm=llm, + output_dir=aligned_dir, + output_root=output_root, + ) + aligned_object_by_id = { + item["id"]: item for item in alignment_result["objects"] + } + for generated_object in generated_objects: + aligned_object = aligned_object_by_id.get(generated_object["id"]) + if aligned_object is not None: + generated_object["aligned_geometry_path"] = aligned_object[ + "aligned_geometry_path" + ] + except Exception as exc: + status = "failed" + failure_reason = traceback.format_exc() + log_warning(f"image object alignment failed error={exc}") + alignment_result = { + "status": "failed", + "reason": failure_reason, + } + + manifest_paths = _write_multi_object_layout_manifests( + glb_gen_dir=glb_gen_dir, + output_root=output_root, + table=generated_table, + objects=generated_objects, + alignment=alignment_result, + ) + table_fields = ( + "id", + "name", + "status", + "is_complete_visible_table", + "complete_table_description", + "table_asset_source", + "support_normal_source", + "image_path", + "raw_geometry_path", + "support_reference_geometry_path", + "generated_table_raw_geometry_path", + "transformed_geometry_path", + "simready_geometry_path", + "aligned_geometry_path", + "mesh_path", + ) + object_fields = ( + "id", + "name", + "status", + "image_path", + "mesh_path", + "aligned_geometry_path", + "metric_scale", + ) + workflow_table = ( + {key: generated_table[key] for key in table_fields if key in generated_table} + if generated_table is not None + else None + ) + workflow_objects = [ + {key: item[key] for key in object_fields if key in item} + for item in generated_objects + ] + if workflow_table is not None and workflow_table.get("status") != "ok": + workflow_table["status"] = "failed" + for item in workflow_objects: + if item.get("status") != "ok": + item["status"] = "failed" + workflow_alignment = ( + { + key: alignment_result[key] + for key in ("status", "final_clutter_2d_aabb_cm") + if key in alignment_result + } + if alignment_result is not None + else None + ) + result = { + "status": status, + "table": workflow_table, + "objects": workflow_objects, + "alignment": workflow_alignment, + "manifests": manifest_paths, + } + if failure_reason: + result["reason"] = failure_reason + log_info( + "image object layout generation completed " + f"status={status} generated={len(generated_objects)}" + ) + return result + + +def _estimate_image_scene_metric_scales( + *, + objects: list[dict[str, Any]], + bbox_name_image_path: Any, + output_dir: Path, + output_root: Path, + llm: Any | None, +) -> dict[str, Any]: + result: dict[str, Any] = { + "status": "skipped", + "method": "image_scene_bbox_name_vlm_candidate_shape_ratio_median_scale", + "bbox_name_image_path": str(bbox_name_image_path or ""), + "objects": [], + } + try: + if not METRIC_SCALE_ENABLED: + result["reason"] = "metric_scale_disabled" + MetricScaleManager.set_for_all_objects( + objects=objects, + status="skipped", + reason="metric_scale_disabled", + method=str(result["method"]), + ) + return result + if llm is None: + result["reason"] = "missing_llm" + MetricScaleManager.set_for_all_objects( + objects=objects, + status="skipped", + reason="missing_llm", + method=str(result["method"]), + ) + return result + + bbox_image = _resolve_generated_path(bbox_name_image_path, output_root) + if not bbox_image.is_file(): + result["reason"] = "missing_bbox_name_image" + MetricScaleManager.set_for_all_objects( + objects=objects, + status="skipped", + reason="missing_bbox_name_image", + method=str(result["method"]), + ) + return result + + metric_objects = _build_metric_scale_inputs( + objects=objects, + output_root=output_root, + ) + result["objects"] = MetricScaleManager.object_prompt_payload(metric_objects) + metric_result = MetricScaleManager.estimate_metric_scales( + EstimateMetricScalesRequest( + objects=metric_objects, + messages=build_image_metric_scale_messages( + bbox_name_image_path=bbox_image, + objects_json=result["objects"], + ), + schema=IMAGE_METRIC_SCALE_JSON_SCHEMA, + llm=llm, + context="Image scene metric scale estimate", + method=str(result["method"]), + step_name=UNIFIED_SCENE_STEP, + raw_output_path=output_dir / "image_metric_scale_raw_model_output.json", + ) + ) + estimates = metric_result.object_scales + MetricScaleManager.apply_to_objects(objects=objects, object_scales=estimates) + result.update( + { + "status": "ok", + "object_scales": estimates, + "unit_note": ( + "Per-object scale_factor is not baked into simready GLBs. " + "Image alignment later computes one clamped global clutter " + "scale from these per-object estimates, on top of SAM3D " + "per-object layout scale." + ), + } + ) + except Exception: + result.update({"status": "failed", "reason": traceback.format_exc()}) + MetricScaleManager.set_for_all_objects( + objects=objects, + status="failed", + reason="image_scene_metric_scale_failed", + method=str(result["method"]), + ) + return result + + +def _build_metric_scale_inputs( + *, + objects: list[dict[str, Any]], + output_root: Path, +) -> list[MetricScaleObjectInput]: + inputs: list[MetricScaleObjectInput] = [] + for obj in objects: + mesh_path = _resolve_generated_path( + obj.get("simready_geometry_path") or obj.get("mesh_path"), + output_root, + ) + if not mesh_path.is_file(): + raise FileNotFoundError(f"Simready object GLB not found: {mesh_path}") + inputs.append( + MetricScaleObjectInput( + object_id=str(obj.get("id", "")), + object_name=str(obj.get("name", "")), + object_description=str(obj.get("description", "")), + mesh_path=mesh_path, + ) + ) + return inputs + + +def _resolve_generated_path(value: Any, output_root: Path) -> Path: + path = Path(str(value or "")).expanduser() + if path.is_absolute(): + return path.resolve() + return (output_root / path).resolve() diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_fit_scene.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_fit_scene.py new file mode 100644 index 000000000..ae96b3a39 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_fit_scene.py @@ -0,0 +1,105 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import traceback +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.table_clutter_fit_manager import ( + fit_table_to_clutter, +) +from embodichain.gen_sim.prompt2scene.utils.log import log_info, log_warning + +__all__ = ["fit_image_scene_table", "fit_text_scene_table"] + + +def fit_text_scene_table( + *, + table_result: dict[str, Any], + clutter_layout_result: dict[str, Any], + output_root: Path, + output_dir: Path, +) -> dict[str, Any]: + """Fit the text-scene table and convert failures to result data.""" + try: + result = fit_table_to_clutter( + table_result=table_result, + clutter_result=clutter_layout_result, + output_root=output_root, + output_dir=output_dir, + ) + log_info(f"text table fit completed status={result.get('status')}") + return result + except Exception as exc: + log_warning(f"text table fit failed error={exc}") + return { + "status": "failed", + "reason": traceback.format_exc(), + } + + +def fit_image_scene_table( + *, + layout_result: dict[str, Any], + fallback_table_result: dict[str, Any] | None, + output_root: Path, + output_dir: Path, +) -> dict[str, Any]: + """Fit the image-scene table or return a structured skipped result.""" + generated_table = layout_result.get("table") or fallback_table_result + generated_objects = layout_result.get("objects") or [] + alignment_result = layout_result.get("alignment") + if ( + generated_table is None + or not generated_objects + or not isinstance(alignment_result, dict) + ): + return { + "status": "skipped", + "reason": "missing_table_objects_or_alignment", + } + + try: + clutter_result = { + "clutter_2d_aabb_cm": alignment_result.get( + "final_clutter_2d_aabb_cm" + ), + "objects": [ + { + "id": item["id"], + "status": "ok", + "laid_out_glb_path": item["aligned_geometry_path"], + } + for item in generated_objects + if item.get("id") and item.get("aligned_geometry_path") + ], + } + result = fit_table_to_clutter( + table_result=generated_table, + clutter_result=clutter_result, + output_root=output_root, + output_dir=output_dir, + ) + log_info(f"image table fit completed status={result.get('status')}") + return result + except Exception as exc: + log_warning(f"image table fit failed error={exc}") + return { + "status": "failed", + "reason": traceback.format_exc(), + } diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py new file mode 100644 index 000000000..1beb76039 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py @@ -0,0 +1,294 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import shutil +import traceback +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_generation_manager import ( + GeometryGenerationManager, + RgbaImageToGeometryRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_generation_manager import ( + ImageGenerationManager, + TextToAssetImageRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_segmentation_manager import ( + AssetImageToRgbaRequest, + ImageSegmentationManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simready_manager import ( + MakeAssetSimreadyRequest, + MakeTableSimreadyRequest, + SimreadyManager, +) +from embodichain.gen_sim.prompt2scene.utils.log import log_info, log_warning + +__all__ = [ + "generate_text_object_asset", + "generate_text_object_assets", + "generate_text_table_asset", +] + + +def generate_text_object_asset( + *, + object_spec: dict[str, Any], + image_gen_dir: Path, + glb_gen_dir: Path, + debug_dir: Path, +) -> dict[str, Any]: + """Generate one object asset from a text-origin object spec.""" + object_id = str(object_spec.get("id", "object")) + object_name = str(object_spec.get("name", "")) + description = str(object_spec.get("description", "")) + class_candidates = [ + str(candidate).replace("_", " ") + for candidate in object_spec.get("class_candidate", []) + if isinstance(candidate, str) and candidate.strip() + ] + status = "ok" + image_path = "" + raw_geometry_path = "" + mesh_path = "" + raw_to_simready_matrix: list[list[float]] = [] + + debug_subdir = debug_dir / object_id + debug_subdir.mkdir(parents=True, exist_ok=True) + log_info(f"text object generation started id={object_id} name={object_name}") + + image_manager = ImageGenerationManager() + segmentation_manager = ImageSegmentationManager() + geometry_manager = GeometryGenerationManager() + simready_manager = SimreadyManager() + + try: + image_prompt = f"{object_name}, {description}".strip(", ") + raw_image_path = str( + image_manager.generate_asset_image_from_text( + TextToAssetImageRequest( + prompt=image_prompt, + output_path=debug_subdir / f"{object_id}.png", + ) + ) + ) + + rgba_prompts: list[str] = [] + if description.strip(): + rgba_prompts.append(description.strip()) + for candidate in class_candidates: + candidate_prompt = f"The entire {candidate} on the center of the image" + if candidate_prompt not in rgba_prompts: + rgba_prompts.append(candidate_prompt) + if not rgba_prompts: + rgba_prompts.append( + f"the entire single isolated object {object_name}" + if object_name + else "the entire single isolated object" + ) + + rgba_path = "" + last_rgba_error: Exception | None = None + for prompt in rgba_prompts: + try: + rgba_path = str( + segmentation_manager.convert_asset_image_to_rgba( + AssetImageToRgbaRequest( + image_path=Path(raw_image_path), + prompt=prompt, + output_path=image_gen_dir / f"{object_id}.png", + ) + ) + ) + break + except Exception as exc: + last_rgba_error = exc + log_warning( + "text object segmentation prompt failed " + f"id={object_id} prompt={prompt!r} error={exc}" + ) + if not rgba_path: + raise last_rgba_error or RuntimeError( + f"No RGBA prompt succeeded for {object_id}" + ) + + raw_glb_path = str( + geometry_manager.convert_rgba_image_to_geometry( + RgbaImageToGeometryRequest( + image_path=Path(rgba_path), + output_path=debug_subdir / f"{object_id}_raw.glb", + ) + ) + ) + raw_geometry_dir = glb_gen_dir / "raw_downloads" + raw_geometry_dir.mkdir(parents=True, exist_ok=True) + object_raw_path = raw_geometry_dir / f"{object_id}_raw.glb" + shutil.copy2(raw_glb_path, object_raw_path) + raw_geometry_path = str(object_raw_path) + + simready_result = simready_manager.make_asset_simready( + MakeAssetSimreadyRequest( + input_path=Path(raw_glb_path), + output_path=glb_gen_dir + / "text_objects_simready" + / f"{object_id}_simready.glb", + ) + ) + mesh_path = str(simready_result.output_path) + raw_to_simready_matrix = simready_result.transform_matrix + + image_path = rgba_path + log_info(f"text object generation completed id={object_id} mesh={mesh_path}") + except Exception as exc: + status = f"failed: {traceback.format_exc()}" + log_warning(f"text object generation failed id={object_id} error={exc}") + + return { + "id": object_id, + "name": object_name, + "status": status, + "image_path": image_path, + "raw_geometry_path": raw_geometry_path, + "mesh_path": mesh_path, + "simready_geometry_path": mesh_path, + "raw_to_simready_glb_matrix": raw_to_simready_matrix, + "metric_scale": None, + } + + +def generate_text_object_assets( + *, + object_specs: list[dict[str, Any]], + image_gen_dir: Path, + glb_gen_dir: Path, + debug_dir: Path, +) -> list[dict[str, Any]]: + """Generate all object assets for a text-origin unified scene.""" + log_info(f"text object batch generation started count={len(object_specs)}") + results = [ + generate_text_object_asset( + object_spec=object_spec, + image_gen_dir=image_gen_dir, + glb_gen_dir=glb_gen_dir, + debug_dir=debug_dir, + ) + for object_spec in object_specs + ] + succeeded = sum(result.get("status") == "ok" for result in results) + log_info( + f"text object batch generation completed " + f"succeeded={succeeded} failed={len(results) - succeeded}" + ) + return results + + +def generate_text_table_asset( + *, + table_spec: dict[str, Any], + image_gen_dir: Path, + glb_gen_dir: Path, + debug_dir: Path, +) -> dict[str, Any]: + """Generate the table asset for a text-origin unified scene.""" + table_id = str(table_spec.get("id", "table")) + description = str( + table_spec.get("complete_table_description") + or table_spec.get("description", "") + ).strip() + status = "ok" + image_path = "" + raw_geometry_path = "" + generated_table_raw_geometry_path = "" + mesh_path = "" + + debug_subdir = debug_dir / table_id + debug_subdir.mkdir(parents=True, exist_ok=True) + log_info(f"text table generation started id={table_id}") + + image_manager = ImageGenerationManager() + segmentation_manager = ImageSegmentationManager() + geometry_manager = GeometryGenerationManager() + simready_manager = SimreadyManager() + + try: + raw_image_path = str( + image_manager.generate_asset_image_from_text( + TextToAssetImageRequest( + prompt=description, + output_path=debug_subdir / f"{table_id}.png", + ) + ) + ) + rgba_path = str( + segmentation_manager.convert_asset_image_to_rgba( + AssetImageToRgbaRequest( + image_path=Path(raw_image_path), + prompt=description if description.strip() else "whole table", + output_path=image_gen_dir / f"{table_id}.png", + ) + ) + ) + raw_glb_path = str( + geometry_manager.convert_rgba_image_to_geometry( + RgbaImageToGeometryRequest( + image_path=Path(rgba_path), + output_path=debug_subdir / f"{table_id}_raw.glb", + ) + ) + ) + generated_table_raw_geometry_path = raw_glb_path + raw_geometry_dir = glb_gen_dir / "raw_downloads" + raw_geometry_dir.mkdir(parents=True, exist_ok=True) + table_raw_path = raw_geometry_dir / "table_raw.glb" + shutil.copy2(raw_glb_path, table_raw_path) + raw_geometry_path = str(table_raw_path) + mesh_path = str( + simready_manager.make_table_simready( + MakeTableSimreadyRequest( + input_path=Path(raw_geometry_path), + output_path=glb_gen_dir + / "text_objects_simready" + / f"{table_id}_simready.glb", + ) + ).output_path + ) + image_path = rgba_path + log_info(f"text table generation completed id={table_id} mesh={mesh_path}") + except Exception as exc: + status = f"failed: {traceback.format_exc()}" + log_warning(f"text table generation failed id={table_id} error={exc}") + + return { + "id": table_id, + "name": str(table_spec.get("name", "table")), + "description": str(table_spec.get("description", "")), + "complete_table_description": description, + "is_complete_visible_table": bool( + table_spec.get("is_complete_visible_table", False) + ), + "status": status, + "image_path": image_path, + "raw_geometry_path": raw_geometry_path, + "generated_table_raw_geometry_path": generated_table_raw_geometry_path, + "support_reference_geometry_path": "", + "table_asset_source": "description_generated", + "support_normal_source": "", + "mesh_path": mesh_path, + "simready_geometry_path": mesh_path, + } diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_clutter_layout.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_clutter_layout.py new file mode 100644 index 000000000..80bc32100 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_clutter_layout.py @@ -0,0 +1,62 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import traceback +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.utils.log import log_info, log_warning +from embodichain.gen_sim.prompt2scene.agent_tools.managers.text_layout_manager import ( + settle_text_objects_to_ground, +) + +__all__ = ["generate_text_clutter_layout"] + + +def generate_text_clutter_layout( + *, + object_results: list[dict[str, Any]], + spatial_relations: list[dict[str, Any]], + table_constraints: list[dict[str, Any]], + output_dir: Path, + output_root: Path, +) -> dict[str, Any]: + """Settle and spatially arrange generated text-scene objects.""" + if not object_results: + return { + "status": "skipped", + "reason": "no_text_objects", + } + + try: + log_info(f"text clutter layout started count={len(object_results)}") + result = settle_text_objects_to_ground( + objects=object_results, + spatial_relations=spatial_relations, + table_constraints=table_constraints, + output_dir=output_dir, + output_root=output_root, + ) + log_info(f"text clutter layout completed status={result.get('status')}") + return result + except Exception as exc: + log_warning(f"text clutter layout failed error={exc}") + return { + "status": "failed", + "reason": traceback.format_exc(), + } diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_scene_metric_scale.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_scene_metric_scale.py new file mode 100644 index 000000000..fd0b13835 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_scene_metric_scale.py @@ -0,0 +1,161 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import traceback +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.metric_scale_manager import ( + METRIC_SCALE_ENABLED, + EstimateMetricScalesRequest, + MetricScaleManager, + MetricScaleObjectInput, +) +from embodichain.gen_sim.prompt2scene.utils.io import write_json +from embodichain.gen_sim.prompt2scene.utils.log import log_info, log_warning + +__all__ = ["build_metric_scale_inputs", "estimate_text_scene_metric_scale"] + + +def estimate_text_scene_metric_scale( + *, + object_results: list[dict[str, Any]], + user_text: str, + messages: list[dict[str, Any]], + schema: dict[str, Any], + output_dir: Path, + output_root: Path, + llm: Any | None, + step_name: str, +) -> dict[str, Any]: + """Estimate real-world scales for generated text-scene objects.""" + result: dict[str, Any] = { + "status": "skipped", + "method": "text_scene_vlm_candidate_shape_ratio_median_scale", + "user_text": user_text, + "objects": [], + } + try: + if not object_results: + result["reason"] = "missing_objects" + log_warning("text scene metric scale skipped reason=missing_objects") + return result + if not METRIC_SCALE_ENABLED: + result["reason"] = "metric_scale_disabled" + MetricScaleManager.set_for_all_objects( + objects=object_results, + status="skipped", + reason="metric_scale_disabled", + method=str(result["method"]), + ) + log_info("text scene metric scale skipped reason=metric_scale_disabled") + return result + if llm is None: + result["reason"] = "missing_llm" + MetricScaleManager.set_for_all_objects( + objects=object_results, + status="skipped", + reason="missing_llm", + method=str(result["method"]), + ) + log_warning("text scene metric scale skipped reason=missing_llm") + return result + + log_info(f"text scene metric scale started count={len(object_results)}") + metric_objects = build_metric_scale_inputs( + objects=object_results, + output_root=output_root, + ) + result["objects"] = MetricScaleManager.object_prompt_payload(metric_objects) + metric_result = MetricScaleManager.estimate_metric_scales( + EstimateMetricScalesRequest( + objects=metric_objects, + messages=messages, + schema=schema, + llm=llm, + context="Text scene metric scale estimate", + method=str(result["method"]), + step_name=step_name, + raw_output_path=output_dir / "raw_model_output.json", + ) + ) + raw_model_output = metric_result.raw_model_output or {} + if not (output_dir / "raw_model_output.json").is_file(): + try: + write_json(output_dir / "raw_model_output.json", raw_model_output) + except Exception as exc: + log_warning(f"metric scale raw output write failed error={exc}") + + estimates = metric_result.object_scales + MetricScaleManager.apply_to_objects( + objects=object_results, + object_scales=estimates, + ) + result.update( + { + "status": "ok", + "object_scales": estimates, + "unit_note": ( + "Per-object scale_factor is not baked into simready GLBs. " + "For text input, simready_geometry_path multiplied by this " + "scale_factor gives the estimated real-world size." + ), + } + ) + log_info(f"text scene metric scale completed count={len(estimates)}") + except Exception as exc: + result.update({"status": "failed", "reason": traceback.format_exc()}) + MetricScaleManager.set_for_all_objects( + objects=object_results, + status="failed", + reason="text_scene_metric_scale_failed", + method=str(result["method"]), + ) + log_warning(f"text scene metric scale failed error={exc}") + return result + + +def build_metric_scale_inputs( + *, + objects: list[dict[str, Any]], + output_root: Path, +) -> list[MetricScaleObjectInput]: + inputs: list[MetricScaleObjectInput] = [] + for obj in objects: + mesh_path = _resolve_generated_path( + obj.get("simready_geometry_path") or obj.get("mesh_path"), + output_root, + ) + if not mesh_path.is_file(): + raise FileNotFoundError(f"Simready object GLB not found: {mesh_path}") + inputs.append( + MetricScaleObjectInput( + object_id=str(obj.get("id", "")), + object_name=str(obj.get("name", "")), + object_description=str(obj.get("description", "")), + mesh_path=mesh_path, + ) + ) + return inputs + + +def _resolve_generated_path(value: Any, output_root: Path) -> Path: + path = Path(str(value or "")).expanduser() + if path.is_absolute(): + return path.resolve() + return (output_root / path).resolve() diff --git a/embodichain/gen_sim/prompt2scene/cli/__init__.py b/embodichain/gen_sim/prompt2scene/cli/__init__.py new file mode 100644 index 000000000..015c41510 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/cli/__init__.py @@ -0,0 +1,19 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +__all__: list[str] = [] diff --git a/embodichain/gen_sim/prompt2scene/cli/start.py b/embodichain/gen_sim/prompt2scene/cli/start.py new file mode 100644 index 000000000..fdc3a27b5 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/cli/start.py @@ -0,0 +1,90 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import argparse +from pathlib import Path + +from embodichain.gen_sim.prompt2scene.pipeline.runner import run_prompt2scene +from embodichain.gen_sim.prompt2scene.llms import load_llm_config +from embodichain.gen_sim.prompt2scene.workflows.request import Prompt2SceneInput + +__all__ = ["cli_prompt2scene", "main"] + + +def cli_prompt2scene( + image_path: str | None, + text: str | None, + output_root: str, + llm_config_path: str | None = None, +) -> None: + """Run prompt2scene from normalized CLI argument values. + + Args: + image_path: Path to an input image, if image mode is used. + text: Text prompt, if text mode is used. + output_root: Directory where prompt2scene outputs are written. + llm_config_path: Optional path to the LLM config JSON file. + """ + request = Prompt2SceneInput.from_cli_args( + image_path=Path(image_path) if image_path is not None else None, + text=text, + output_root=Path(output_root), + ) + llm_cfg = load_llm_config( + Path(llm_config_path) if llm_config_path is not None else None + ) + run_prompt2scene(request, llm_cfg=llm_cfg) + + +def main() -> None: + """Parse command line arguments and launch the prompt2scene pipeline.""" + parser = argparse.ArgumentParser( + description="embodichain.gen_sim.prompt2scene Prompt-to-Scene Pipeline" + ) + + input_group = parser.add_mutually_exclusive_group(required=True) + input_group.add_argument( + "--image", + type=str, + help="Path to the input image file (.jpg, .jpeg, or .png)", + ) + input_group.add_argument( + "--text", + type=str, + help="Text prompt describing the target scene", + ) + parser.add_argument( + "--output_root", + type=str, + required=True, + help="Path to the output directory", + ) + parser.add_argument( + "--llm_config", + type=str, + default=None, + help="Path to the LLM config JSON file", + ) + + args = parser.parse_args() + + cli_prompt2scene(args.image, args.text, args.output_root, args.llm_config) + + +if __name__ == "__main__": + main() diff --git a/embodichain/gen_sim/prompt2scene/configs/client_config.json b/embodichain/gen_sim/prompt2scene/configs/client_config.json new file mode 100644 index 000000000..b8662eaf2 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/configs/client_config.json @@ -0,0 +1,21 @@ +{ + "sam3_segmentation": { + "base_url": "http://192.168.3.23:5014", + "timeout_s": 1200, + "health_path": "/health", + "segment_single_object_path": "/predict" + }, + "sam3d_generation": { + "base_url": "http://10.7.7.32:5019", + "timeout_s": 1800, + "health_path": "/health", + "generate_multiple_objects_path": "/generate_multiple_objects", + "generate_single_object_path": "/generate_single_object" + }, + "zimage": { + "base_url": "http://192.168.3.23:5013", + "timeout_s": 120, + "health_path": "/health", + "generate_single_object_path": "/generate.png" + } +} diff --git a/embodichain/gen_sim/prompt2scene/configs/llm_config.json b/embodichain/gen_sim/prompt2scene/configs/llm_config.json new file mode 100644 index 000000000..9dd825143 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/configs/llm_config.json @@ -0,0 +1,11 @@ +{ + "llm": { + "openai_compatible": { + "api_key": "", + "model": "", + "base_url": "", + "default_query": {}, + "max_attempts": 5 + } + } +} diff --git a/embodichain/gen_sim/prompt2scene/llms/__init__.py b/embodichain/gen_sim/prompt2scene/llms/__init__.py new file mode 100644 index 000000000..8412eff44 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/llms/__init__.py @@ -0,0 +1,31 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.llms.config import OpenAICompatibleLLMCfg +from embodichain.gen_sim.prompt2scene.llms.openai_compatible import ( + DEFAULT_LLM_CONFIG_PATH, + build_chat_model, + load_llm_config, +) + +__all__ = [ + "DEFAULT_LLM_CONFIG_PATH", + "OpenAICompatibleLLMCfg", + "build_chat_model", + "load_llm_config", +] diff --git a/embodichain/gen_sim/prompt2scene/llms/config.py b/embodichain/gen_sim/prompt2scene/llms/config.py new file mode 100644 index 000000000..f84c4fcf9 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/llms/config.py @@ -0,0 +1,49 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass, field + +__all__ = [ + "OpenAICompatibleLLMCfg", +] + + +@dataclass(frozen=True) +class OpenAICompatibleLLMCfg: + """OpenAI-compatible LLM configuration.""" + + api_key: str + model: str + base_url: str + default_query: dict[str, str] = field(default_factory=dict) + max_attempts: int = 3 + + def to_manifest(self) -> dict[str, object]: + """Convert the LLM config to a JSON-safe manifest. + + Returns: + LLM config metadata with sensitive values removed. + """ + return { + "provider": "openai_compatible", + "model": self.model, + "base_url": self.base_url, + "has_api_key": bool(self.api_key), + "default_query": self.default_query, + "max_attempts": self.max_attempts, + } diff --git a/embodichain/gen_sim/prompt2scene/llms/openai_compatible.py b/embodichain/gen_sim/prompt2scene/llms/openai_compatible.py new file mode 100644 index 000000000..91e94a594 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/llms/openai_compatible.py @@ -0,0 +1,115 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +import os +from pathlib import Path +from typing import Any + +from langchain_openai import ChatOpenAI + +from embodichain.gen_sim.prompt2scene.llms.config import OpenAICompatibleLLMCfg + +__all__ = ["DEFAULT_LLM_CONFIG_PATH", "build_chat_model", "load_llm_config"] + +DEFAULT_LLM_CONFIG_PATH = ( + Path(__file__).resolve().parents[1] / "configs" / "llm_config.json" +) + + +def load_llm_config(config_path: Path | None = None) -> OpenAICompatibleLLMCfg: + """Load the prompt2scene OpenAI-compatible LLM config. + + Args: + config_path: Optional path to the LLM config JSON file. + + Returns: + Parsed OpenAI-compatible LLM config. + + Raises: + FileNotFoundError: If the config file does not exist. + ValueError: If required config fields are missing. + """ + config_path = config_path or DEFAULT_LLM_CONFIG_PATH + config_path = config_path.expanduser().resolve() + + if not config_path.exists(): + raise FileNotFoundError(f"LLM config not found: {config_path}") + + with config_path.open("r", encoding="utf-8") as f: + raw_cfg: dict[str, Any] = json.load(f) + + cfg = raw_cfg.get("llm", {}).get("openai_compatible", {}) + api_key = os.getenv("OPENAI_API_KEY") or cfg.get("api_key", "") + model = os.getenv("OPENAI_MODEL") or cfg.get("model", "") + base_url = os.getenv("OPENAI_BASE_URL") or cfg.get("base_url", "") + default_query = cfg.get("default_query", {}) + max_attempts = _load_positive_int( + os.getenv("OPENAI_MAX_ATTEMPTS") or cfg.get("max_attempts", 3), + key="max_attempts", + ) + + if base_url: + base_url = base_url.rstrip("/") + + missing = [ + name + for name, value in { + "api_key": api_key, + "model": model, + "base_url": base_url, + }.items() + if not value + ] + if missing: + raise ValueError(f"Missing required LLM config keys: {missing}") + + if not isinstance(default_query, dict): + raise ValueError("LLM config key default_query must be a dict.") + + return OpenAICompatibleLLMCfg( + api_key=api_key, + model=model, + base_url=base_url, + default_query=default_query, + max_attempts=max_attempts, + ) + + +def _load_positive_int(value: object, *, key: str) -> int: + try: + parsed = int(value) + except (TypeError, ValueError) as exc: + raise ValueError(f"LLM config key {key} must be an integer.") from exc + if parsed < 1: + raise ValueError(f"LLM config key {key} must be >= 1.") + return parsed + + +def build_chat_model(cfg: OpenAICompatibleLLMCfg) -> Any: + """Build a LangChain OpenAI-compatible chat model.""" + kwargs: dict[str, Any] = { + "api_key": cfg.api_key, + "base_url": cfg.base_url, + "model": cfg.model, + "temperature": 0, + } + if cfg.default_query: + kwargs["default_query"] = cfg.default_query + + return ChatOpenAI(**kwargs) diff --git a/embodichain/gen_sim/prompt2scene/pipeline/__init__.py b/embodichain/gen_sim/prompt2scene/pipeline/__init__.py new file mode 100644 index 000000000..a1450f03c --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/pipeline/__init__.py @@ -0,0 +1,25 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.pipeline.runner import ( + Prompt2SceneRunResult, + run_prompt2scene, +) + +__all__ = ["Prompt2SceneRunResult", "run_prompt2scene"] + diff --git a/embodichain/gen_sim/prompt2scene/pipeline/runner.py b/embodichain/gen_sim/prompt2scene/pipeline/runner.py new file mode 100644 index 000000000..7931f00ba --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/pipeline/runner.py @@ -0,0 +1,239 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + +from embodichain.gen_sim.prompt2scene.llms import OpenAICompatibleLLMCfg +from embodichain.gen_sim.prompt2scene.workflows.request import ( + InputKind, + Prompt2SceneInput, +) +from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( + IMAGE_SEGMENTS_STEP, + IMAGE_SPATIAL_RELATIONS_STEP, + SCENE_INTAKE_STEP, + STEP_RESULT_FILENAME, + step_result_path, + write_step_result, + TEXT_RELATIONS_STEP, + UNIFIED_SCENE_STEP, +) +from embodichain.gen_sim.prompt2scene.workflows.unified_scene.graph import ( + run_unified_scene, +) +from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.graph import ( + run_unified_scene_gen, +) +from embodichain.gen_sim.prompt2scene.agent_tools.tools.gym_export import ( + export_gym_config, +) +from embodichain.gen_sim.prompt2scene.utils.io import write_json +from embodichain.gen_sim.prompt2scene.utils import log +from embodichain.gen_sim.prompt2scene.workflows.image_relations import ( + run_image_relations, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake import run_scene_intake +from embodichain.gen_sim.prompt2scene.workflows.text_relations import ( + run_text_relations, +) + +__all__ = [ + "IMAGE_SEGMENTS_DIRNAME", + "IMAGE_SPATIAL_RELATIONS_DIRNAME", + "INPUT_MANIFEST_FILENAME", + "SCENE_INTAKE_DIRNAME", + "STEP_RESULT_FILENAME", + "TEXT_RELATIONS_DIRNAME", + "UNIFIED_SCENE_DIRNAME", + "Prompt2SceneRunResult", + "run_prompt2scene", +] + +INPUT_MANIFEST_FILENAME = "input_manifest.json" +SCENE_INTAKE_DIRNAME = SCENE_INTAKE_STEP +IMAGE_SEGMENTS_DIRNAME = IMAGE_SEGMENTS_STEP +IMAGE_SPATIAL_RELATIONS_DIRNAME = IMAGE_SPATIAL_RELATIONS_STEP +TEXT_RELATIONS_DIRNAME = TEXT_RELATIONS_STEP +UNIFIED_SCENE_DIRNAME = UNIFIED_SCENE_STEP + + +@dataclass(frozen=True) +class Prompt2SceneRunResult: + """Result returned by the prompt2scene runner. + + Args: + output_root: Directory where prompt2scene outputs were written. + manifest_path: Path to the serialized input manifest. + scene_intake_path: Path to the serialized scene intake output. + image_segments_path: Path to serialized image segment alignment output. + image_spatial_relations_path: Path to serialized image spatial relations. + text_relations_path: Path to serialized text spatial relations. + unified_scene_path: Path to serialized unified scene output. + """ + + output_root: Path + manifest_path: Path + scene_intake_path: Path | None = None + image_segments_path: Path | None = None + image_spatial_relations_path: Path | None = None + text_relations_path: Path | None = None + unified_scene_path: Path | None = None + gym_config_path: Path | None = None + + +def run_prompt2scene( + request: Prompt2SceneInput, + llm_cfg: OpenAICompatibleLLMCfg | None = None, +) -> Prompt2SceneRunResult: + """Run the prompt2scene pipeline. + + This runner creates the output directory, writes the parsed input manifest, + and runs fixed VLM-based scene intake when an LLM config is provided. + + Args: + request: Parsed prompt2scene input. + llm_cfg: Optional LLM config used by later pipeline stages. + + Returns: + Paths created by the runner. + """ + log.log_info( + "run start " + f"input_kind={request.input_kind.value} output_root={request.output_root}" + ) + request.output_root.mkdir(parents=True, exist_ok=True) + manifest_path = request.output_root / INPUT_MANIFEST_FILENAME + manifest = request.to_manifest() + if llm_cfg is not None: + manifest["llm"] = llm_cfg.to_manifest() + write_json(manifest_path, manifest) + + scene_intake_path = None + image_segments_path = None + image_spatial_relations_path = None + text_relations_path = None + unified_scene_path = None + gym_config_path = None + if llm_cfg is not None: + log.log_info("step start scene_intake") + scene_intake = run_scene_intake(request, llm_cfg=llm_cfg) + scene_intake_path = write_step_result( + request.output_root, + SCENE_INTAKE_STEP, + scene_intake.to_manifest(), + ) + log.log_info( + f"step end scene_intake status=ok output={scene_intake_path}" + ) + if request.input_kind == InputKind.IMAGE: + log.log_info("step start image_relations") + image_relations = run_image_relations( + request, + scene_intake=scene_intake, + llm_cfg=llm_cfg, + output_root=request.output_root, + ) + image_segments_path = step_result_path( + request.output_root, + IMAGE_SEGMENTS_STEP, + ) + if not image_segments_path.is_file(): + write_step_result( + request.output_root, + IMAGE_SEGMENTS_STEP, + image_relations.to_segmentation_manifest(), + ) + image_spatial_relations_path = step_result_path( + request.output_root, + IMAGE_SPATIAL_RELATIONS_STEP, + ) + if not image_spatial_relations_path.is_file(): + write_step_result( + request.output_root, + IMAGE_SPATIAL_RELATIONS_STEP, + image_relations.to_spatial_manifest(), + ) + log.log_info( + "step end image_relations " + f"status=ok output={image_spatial_relations_path}" + ) + log.log_info("step start unified_scene") + unified_scene = run_unified_scene( + request, + scene_intake=scene_intake, + image_relations=image_relations, + output_root=request.output_root, + ) + unified_scene_path = step_result_path( + request.output_root, + UNIFIED_SCENE_STEP, + ) + else: + log.log_info("step start text_relations") + text_relations = run_text_relations( + request, + scene_intake=scene_intake, + llm_cfg=llm_cfg, + output_root=request.output_root, + ) + text_relations_path = step_result_path( + request.output_root, + TEXT_RELATIONS_STEP, + ) + log.log_info( + f"step end text_relations status=ok output={text_relations_path}" + ) + log.log_info("step start unified_scene") + unified_scene = run_unified_scene( + request, + scene_intake=scene_intake, + text_relations=text_relations, + output_root=request.output_root, + ) + unified_scene_path = step_result_path( + request.output_root, + UNIFIED_SCENE_STEP, + ) + log.log_info( + f"step end unified_scene status=ok output={unified_scene_path}" + ) + log.log_info("step start unified_scene_gen") + run_unified_scene_gen( + request.output_root, + unified_scene_result_path=unified_scene_path, + llm_cfg=llm_cfg, + ) + log.log_info("step end unified_scene_gen status=ok") + + log.log_info("step start gym_export") + gym_config_path = export_gym_config(request.output_root) + log.log_info(f"step end gym_export status=ok output={gym_config_path}") + + log.log_info(f"run end output_root={request.output_root}") + + return Prompt2SceneRunResult( + output_root=request.output_root, + manifest_path=manifest_path, + scene_intake_path=scene_intake_path, + image_segments_path=image_segments_path, + image_spatial_relations_path=image_spatial_relations_path, + text_relations_path=text_relations_path, + unified_scene_path=unified_scene_path, + gym_config_path=gym_config_path, + ) diff --git a/embodichain/gen_sim/prompt2scene/prompts/__init__.py b/embodichain/gen_sim/prompt2scene/prompts/__init__.py new file mode 100644 index 000000000..f72a97f6d --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/prompts/__init__.py @@ -0,0 +1,48 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from . import data +from .base import PromptRenderer + +default_prompt_renderer = PromptRenderer(data) + +__all__ = ["load_prompt", "load_prompt_data", "render_prompt", "default_prompt_renderer"] + + +def load_prompt(prompt_name: str) -> str: + """Load a prompt template from the bundled prompt data directory.""" + return default_prompt_renderer.load_prompt(prompt_name) + + +def load_prompt_data(prompt_name: str) -> dict[str, object]: + """Load a YAML prompt data file from the bundled prompt data directory.""" + return default_prompt_renderer.load_prompt_data(prompt_name) + + +def render_prompt( + prompt_name: str, + values: dict[str, object] | None = None, + *, + prompt_key: str | None = None, +) -> str: + """Load a prompt template and fill optional placeholders.""" + return default_prompt_renderer.render_prompt( + prompt_name, + values, + prompt_key=prompt_key, + ) diff --git a/embodichain/gen_sim/prompt2scene/prompts/base.py b/embodichain/gen_sim/prompt2scene/prompts/base.py new file mode 100644 index 000000000..a145735cb --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/prompts/base.py @@ -0,0 +1,79 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from functools import lru_cache +from importlib import resources +from pathlib import Path +from string import Template +from typing import Any, Mapping + +import yaml + +__all__ = ["PromptRenderer"] + + +class PromptRenderer: + """Load and render bundled prompt templates.""" + + def __init__(self, package: Any) -> None: + self._package = package + + @lru_cache(maxsize=None) + def load_prompt(self, prompt_name: str) -> str: + """Load a plain-text prompt template by file name.""" + prompt_path = self._get_prompt_path(prompt_name) + if not prompt_path.is_file(): + raise FileNotFoundError(f"Prompt data file not found: {prompt_name}") + return prompt_path.read_text(encoding="utf-8").strip() + + @lru_cache(maxsize=None) + def load_prompt_data(self, prompt_name: str) -> dict[str, Any]: + """Load a YAML prompt data file by file name.""" + prompt_path = self._get_prompt_path(prompt_name) + if not prompt_path.is_file(): + raise FileNotFoundError(f"Prompt data file not found: {prompt_name}") + + prompt_data = yaml.safe_load(prompt_path.read_text(encoding="utf-8")) + if not isinstance(prompt_data, dict): + raise ValueError(f"Prompt YAML must contain a mapping: {prompt_name}") + return prompt_data + + def render_prompt( + self, + prompt_name: str, + values: Mapping[str, object] | None = None, + *, + prompt_key: str | None = None, + ) -> str: + """Render a prompt template and fill placeholders.""" + if prompt_key is None: + template = self.load_prompt(prompt_name) + else: + prompt_data = self.load_prompt_data(prompt_name) + template = prompt_data.get(prompt_key) + if not isinstance(template, str): + raise KeyError(f"Prompt key {prompt_key!r} not found in {prompt_name}") + + if values is None: + return template + return Template(template).safe_substitute(values) + + def _get_prompt_path(self, prompt_name: str) -> Path: + if "/" in prompt_name or "\\" in prompt_name: + raise ValueError(f"Prompt name must be a file name: {prompt_name}") + return resources.files(self._package).joinpath(prompt_name) diff --git a/embodichain/gen_sim/prompt2scene/prompts/data/__init__.py b/embodichain/gen_sim/prompt2scene/prompts/data/__init__.py new file mode 100644 index 000000000..96d642123 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/prompts/data/__init__.py @@ -0,0 +1,21 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Bundled prompt template data files.""" + +from __future__ import annotations + +__all__: list[str] = [] diff --git a/embodichain/gen_sim/prompt2scene/prompts/data/image_relations.yaml b/embodichain/gen_sim/prompt2scene/prompts/data/image_relations.yaml new file mode 100644 index 000000000..50ed69647 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/prompts/data/image_relations.yaml @@ -0,0 +1,238 @@ +name: image_relations +version: 1 + +filter_extra_instances_system: | + + You are a careful image segmentation verification assistant for tabletop scenes. + + + + You will receive: + - One target object class name. + - One target object description. + - The expected number of target instances. + - A short candidate class list for that target object. + - One image with numbered colored masks drawn over candidate segmentation + results for that target object. + + Your only task is to choose which numbered masks should be removed so the + remaining masks best match the requested object class, target description, and + expected instance count. + + This is not a scene-description task and not a spatial-relation task. + Do not describe the scene. Do not infer object-object relations. Do not rename + the requested object class. Do not add new masks. + + + + - Use the target object class name as the primary class. + - Use the target description to distinguish visually similar objects from the + same broad category. + - Use the expected instance count as a hard target when enough plausible masks + are available. + - Use the candidate class list only as synonyms or fallback names for the same + target object. + - If more plausible masks are present than the expected count, keep only the + expected number of best matches and remove the rest. + - If exactly the expected number of plausible masks are present, keep them. + - If fewer than the expected number of plausible masks are present, keep every + plausible mask and remove only clearly wrong or duplicate masks. + - Remove a numbered mask if it clearly covers a different object class. + - Remove a numbered mask if it is a duplicate detection of the same physical + instance already covered by another better mask. + - Remove a numbered mask if it mostly covers background, a hand, or an + unrelated partial region. + - Remove a numbered mask that mostly covers a table or support region unless + the requested target class itself is that table/support target. + - If a mask is ambiguous but plausibly covers the requested object class, keep + it. + + + + { + "extra_instance_numbers": [3], + "reason": "Mask 3 covers a different object, not the requested class." + } + + + + Example 1: + Target object class: soccer_ball + Target description: A round soccer ball with black-and-white panels. + Expected instance count: 2 + Candidate classes: soccer_ball, football, ball, sports_ball, toy_ball + Observation: Masks 1 and 2 cover two soccer balls. Mask 3 covers a paper cup. + Output: + { + "extra_instance_numbers": [3], + "reason": "Masks 1 and 2 are soccer balls; mask 3 is a paper cup." + } + + Example 2: + Target object class: apple + Target description: A round red apple with smooth skin. + Expected instance count: 1 + Candidate classes: apple, fruit, red_apple, food, produce + Observation: Mask 1 tightly covers the apple. Mask 2 overlaps the same apple and + is a duplicate looser detection. + Output: + { + "extra_instance_numbers": [2], + "reason": "Mask 2 is a duplicate detection of the same apple covered by mask 1." + } + + Example 3: + Target object class: mug + Target description: A white ceramic coffee mug with a handle. + Expected instance count: 1 + Candidate classes: mug, coffee_mug, cup, drinkware, ceramic_cup + Observation: Mask 1 covers a real mug. Mask 2 covers a bowl. + Output: + { + "extra_instance_numbers": [2], + "reason": "Mask 1 is a mug; mask 2 is a bowl and should be removed." + } + + Example 4: + Target object class: fork + Target description: A silver metal fork with four tines. + Expected instance count: 1 + Candidate classes: fork, dinner_fork, utensil, cutlery, tableware + Observation: Mask 1 plausibly covers a fork, although part of it is occluded. + Output: + { + "extra_instance_numbers": [], + "reason": "Mask 1 plausibly covers the requested fork and should be kept." + } + + + + - extra_instance_numbers must contain 1-based mask numbers exactly as shown in + the numbered-mask image. + - If no masks should be removed, output an empty list. + - Output JSON only. Do not include markdown or explanations outside JSON. + + +filter_extra_instances_user: | + Verify the numbered segmentation masks for this object class: + + + Target object class: $name + Target description: $description + Expected instance count: $expected_count + Candidate classes: $class_candidate + + + + Inspect the numbered-mask image. + Return the 1-based numbers of masks that should be removed so the remaining + masks best match the target description and expected instance count. + + +spatial_layout_system: | + + You are a careful tabletop spatial-layout verifier. + + + + You will receive one tabletop image with final bounding boxes and labels for + every detected object instance. Your task is to output: + - One anchor object, its 9-grid table location, and the reason for choosing it + and assigning that grid. + - Object groups ordered from left to right. + - Object groups ordered from front to back. + - Whether each object has arbitrary layout, plus a concise support-pose reason. + + Do not output pairwise left/right/front/behind relations. The program will + derive canonical left_of and front_of relations from your x_order and y_order. + Use ordered groups conservatively. Prefer fewer relations over a wrong + relation. + + + + - x_order must be ordered from image/table left to image/table right. + - y_order must be ordered from table front to table back. + - Split x_order groups when the left/right order is reasonably clear from the + bbox-name image. + - If an object's left/right order is ambiguous, keep it in a shared x_order + group. Never omit it. + - Front/back is especially hard to judge. Split y_order only when depth + separation is obvious, preferably from contact positions or bbox bottoms. + - If front/back is close, roughly collinear, overlapping, occluded, similarly + aligned, or hard to compare, place objects in the same y_order group. + - Ordered groups are interpreted as monotonic DAG ranks. The program only + creates direct edges between adjacent groups, then derives transitive + closure. For example, G1 < G2 < G3 creates direct edges G1 -> G2 and + G2 -> G3; G1 -> G3 is implicit. + + + + - Choose one clearly visible object as anchor. + - Prefer a large, unoccluded object whose 9-grid location is easy to judge. + - The anchor reason must explain both why this object was selected and why its + grid is correct. + - The anchor grid must be one of: + center, front, back, left_center, right_center, left_front, right_front, + left_back, right_back. + + + + - is_arbitrary_layout is true when the object does not need a specified + support pose before physics simulation, such as balls, round fruits, loose + natural objects, or objects that will naturally settle by gravity. + - is_arbitrary_layout is false when the object needs a deliberate support pose, + such as cups, bottles, cans, boxes, utensils, remotes, blocks, bags, or + objects that should stand or lie in a controlled way. + - If is_arbitrary_layout is false, the reason must describe the default support + pose visible or implied in the image, such as standing upright on the table, + lying flat on the table, lying on its side, or leaning against another object. + - If is_arbitrary_layout is true, the reason must explain that the object can + settle naturally under gravity or has no meaningful preset support pose. + + + + { + "anchor": { + "asset_id": "interact_paper_cup_0", + "grid": "center", + "reason": "The paper cup is clearly visible and near the table center, so it is a reliable anchor for the center grid." + }, + "x_order": [ + ["interact_wooden_block_0"], + ["interact_paper_cup_0"], + ["interact_snack_bag_0"] + ], + "y_order": [ + ["interact_paper_cup_0"], + ["interact_wooden_block_0", "interact_snack_bag_0"] + ], + "asset_states": [ + { + "asset_id": "interact_paper_cup_0", + "is_arbitrary_layout": false, + "reason": "The paper cup is standing upright on the table, so it needs a deliberate upright support pose." + } + ] + } + + + + - Every provided asset_id must appear exactly once in x_order. + - Every provided asset_id must appear exactly once in y_order. + - Every provided asset_id must appear exactly once in asset_states. + - Use one large group on an axis if the left-right or front-back order is not + visually obvious. Do not omit uncertain objects. + - anchor.asset_id must be one of the provided asset_ids. + - anchor.reason and every asset state reason must be concise but explicit. + - Only the anchor may have a grid. Do not add grid to asset_states. + - Output JSON only. Do not include markdown or explanations outside JSON. + + +spatial_layout_user: | + Infer spatial order, anchor grid, and object states for these detected object instances: + + + $asset_ids + + + Inspect the attached bbox-name image and return the JSON object. diff --git a/embodichain/gen_sim/prompt2scene/prompts/data/scene_intake.yaml b/embodichain/gen_sim/prompt2scene/prompts/data/scene_intake.yaml new file mode 100644 index 000000000..cabf99cb5 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/prompts/data/scene_intake.yaml @@ -0,0 +1,468 @@ +name: scene_intake +version: 1 + +text_system: | + + You are a careful 3D tabletop scene intake assistant for TEXT input. + + + + You will receive a text description of a tabletop scene. + This is only the first-stage scene intake step: + - Extract the object categories and counts on the tabletop. + - Extract the table or tabletop region that carries the objects, using + the fixed output field named table. + + Do not analyze object-object relations, grids, orientations, stacking, + inside/container relations, layout, pose, masks, bounding boxes, or + segmentation results. + + + + - Output only real physical objects that can become 3D asset generation targets. + - Do not include the table or tabletop region in assets. + - assets is a list of object category groups, not a list of individual object + instances. + - name must be the most specific English, singular, canonical object class + supported by the input. + - Prefer a concrete small category over a broad category. For example, output + fork instead of utensil, paper_cup instead of container, toy_car instead of + toy, remote_control instead of handheld_device, and cereal_box instead of + box when those categories are supported by the input. + - Use a broad fallback name only when the specific object category cannot be + reasonably inferred. + - Prefer snake_case names, such as apple, banana, soccer_ball, coffee_mug. + - Treat multiple objects as one repeated asset group only when they are + effectively the same object type and can share the same name, the same + object-only description, and the same class_candidate list without losing + important visual identity. + - Never output two asset rows with the same name. If the same name would be + repeated, merge them into one row and increase count. + - If repeated instances are truly the same asset group, output exactly one + asset row and set count to the number of visible or described instances. + - If two objects need meaningfully different descriptions, names, or + class_candidate lists, they are not repeated instances. Output separate + asset rows with specific different names. + - Only merge objects when they can reasonably be found by the same segmentation + prompts from name, class_candidate, and description. + - Do not merge visually different subtypes under a broad name. For example, + paper_cup and popcorn_cup must be separate rows, not one cup row; snack_bag + and paper_bag must be separate rows; remote_control and phone must be + separate rows. + - Do not output instance IDs such as apple_0 or banana_0. Instance IDs will be + generated by code from name and count. + - Do not output extra fields such as source_text, source_image_path, image_path, + bbox, mask, or id. + - class_candidate must contain exactly five English, singular, canonical + object class names that could help later image detection or segmentation. + - class_candidate must prioritize specific small categories. The first item + must equal name. The next items should be specific plausible classes before + broader fallback classes. + - Do not replace a known small category with a broad category. If the object is + a fork, include fork first; broader classes such as utensil or cutlery may + appear only later as fallbacks. + - For text inputs, class_candidate should follow the stated object category + and include detector-friendly small-category synonyms before broader + classes. + + + + - table.name, table.description, table.complete_table_description, + table.class_candidate, and every asset.description must be non-empty. + - Descriptions are used to generate images and then 3D geometry. + - Write each description as one concise English sentence, normally 10 to 25 + words. + - Every description must describe a SINGLE STANDALONE OBJECT isolated on a + pure-white background. Do NOT mention any other object, the table, the scene, + the room, or any background context. + - Do NOT include any spatial, positional, or layout information such as + "sitting on the table", "placed in front of", "to the left of", "on a + surface", "on the tabletop", etc. + - When describing an object, first state what the object is, then describe its + appearance in detail. + - For TEXT input you MUST invent reasonable and vivid appearance details: + color (be specific: "crimson red", "matte charcoal", "glossy navy blue", + "warm honey oak"), material (polished stainless steel, glazed ceramic, + rough terracotta, smooth beechwood, frosted glass), texture (ribbed, + brushed, speckled, woven, hammered), shape (cylindrical, tapered, flared + rim, curved handle, wide brim). + - Vary colours across objects — do not make everything white or neutral. + A tabletop scene naturally has diverse materials and hues. + - table.description must describe the actual table as a standalone target: + include type, color, shape, material, and legs/base when applicable. + - table.complete_table_description must describe a complete standalone table + asset for generation. It must always include a complete physical table-like + object, with a tabletop and a plausible support structure such as legs, + pedestal, frame, or tray body. It must not describe only a surface plane, + tabletop patch, texture, or support region. + - Do not write generic phrases such as "support surface", "tabletop", or + "surface" when table.name is a concrete object such as table, desk, tray, + counter, shelf, or floor. Use the concrete class in the description. + - For repeated instances, write one object-only description for the shared + category. Do not mention instance positions. + - If two objects require different descriptions, they must be separate asset + rows with distinct names. + + + + - Do not output a table id. The code will set table.id to "table". + - The table field represents the scene table or tabletop target. table.name + must be the best class name for that target, such as table, desk, + dining_table, coffee_table, workbench, or tabletop. + - table.class_candidate must contain exactly five English, singular, + canonical class names for segmenting the support target. The first item must + equal table.name. + - For text inputs, set table.is_complete_visible_table to false. + + + + { + "table": { + "name": "table", + "description": "A rectangular wooden table with a brown top and four straight legs.", + "complete_table_description": "A complete rectangular wooden table with a brown top and four straight legs.", + "is_complete_visible_table": false, + "class_candidate": ["table", "dining_table", "desk", "wooden_table", "furniture"] + }, + "assets": [ + { + "name": "apple", + "description": "A shiny deep-red apple with a smooth curved shape and a small brown stem on top.", + "class_candidate": ["apple", "fruit", "red_apple", "food", "produce"], + "count": 1 + }, + { + "name": "coffee_mug", + "description": "A glossy navy blue ceramic coffee mug with a curved handle and a slightly flared rim.", + "class_candidate": ["coffee_mug", "ceramic_mug", "mug", "cup", "drinkware"], + "count": 2 + } + ] + } + + + + - The top-level object must contain only table and assets. + - table must contain only name, description, complete_table_description, + is_complete_visible_table, and class_candidate. + - Each asset must contain only name, description, class_candidate, and count. + - table.name must be a non-empty string. + - table.description must be a non-empty string. + - table.complete_table_description must be a non-empty string. + - table.is_complete_visible_table must be a boolean. + - table.class_candidate must be a list of exactly five non-empty strings, and + the first item must equal table.name. + - assets must be a list. + - Each asset.name must be a non-empty string. + - Each asset.description must be a non-empty string. + - Each asset.class_candidate must be a list of exactly five non-empty strings. + - Each asset.count must be an integer greater than or equal to 1. + - Output JSON only. Do not include markdown or explanations outside JSON. + + +image_system: | + + You are a careful 3D tabletop scene intake assistant for IMAGE input. + + + + You will receive one image of a tabletop scene. + This is only the first-stage scene intake step: + - Extract the object categories and counts on the tabletop. + - Extract the visible table or tabletop region that carries the objects, using + the fixed output field named table. + + Do not analyze object-object relations, grids, orientations, stacking, + inside/container relations, layout, pose, masks, bounding boxes, or + segmentation results. + + + + - Output only real physical objects that can become 3D asset generation targets. + - Do not include the table or tabletop region in assets. + - assets is a list of object category groups, not a list of individual object + instances. + - name must be the most specific English, singular, canonical object class + supported by the input. + - Prefer a concrete small category over a broad category. For example, output + fork instead of utensil, paper_cup instead of container, toy_car instead of + toy, remote_control instead of handheld_device, and cereal_box instead of + box when those categories are supported by the input. + - Use a broad fallback name only when the specific object category cannot be + reasonably inferred. + - Prefer snake_case names, such as apple, banana, soccer_ball, coffee_mug. + - Treat multiple objects as one repeated asset group only when they are + effectively the same object type and can share the same name, the same + object-only description, and the same class_candidate list without losing + important visual identity. + - Never output two asset rows with the same name. If the same name would be + repeated, merge them into one row and increase count. + - If repeated instances are truly the same asset group, output exactly one + asset row and set count to the number of visible or described instances. + - If two objects need meaningfully different descriptions, names, or + class_candidate lists, they are not repeated instances. Output separate + asset rows with specific different names. + - Only merge objects when they can reasonably be found by the same segmentation + prompts from name, class_candidate, and description. + - Do not merge visually different subtypes under a broad name. For example, + paper_cup and popcorn_cup must be separate rows, not one cup row; snack_bag + and paper_bag must be separate rows; remote_control and phone must be + separate rows. + - Do not output instance IDs such as apple_0 or banana_0. Instance IDs will be + generated by code from name and count. + - Do not output extra fields such as source_text, source_image_path, image_path, + bbox, mask, or id. + - class_candidate must contain exactly five English, singular, canonical + object class names that could help later image detection or segmentation. + - class_candidate must prioritize specific small categories. The first item + must equal name. The next items should be specific plausible classes before + broader fallback classes. + - Do not replace a known small category with a broad category. If the object is + a fork, include fork first; broader classes such as utensil or cutlery may + appear only later as fallbacks. + - For image inputs, if the exact object category is uncertain, use + class_candidate to list likely categories from specific to broader, such as + remote_control, handheld_device, electronic_device, gadget, tool. + + + + - table.name, table.description, table.complete_table_description, + table.class_candidate, and every asset.description must be non-empty. + - Descriptions are used to generate images and then 3D geometry. + - Write each description as one concise English sentence, normally 8 to 20 + words. + - Every description must describe a SINGLE STANDALONE OBJECT isolated on a + pure-white background. Do NOT mention any other object, the table, the scene, + the room, or any background context. + - Do NOT include any spatial, positional, or layout information such as + "sitting on the table", "placed in front of", "to the left of", "on a + surface", "on the tabletop", etc. + - When describing an object, first state what the object is, then mention + visible texture, color, shape, material, and similar appearance details. + - Keep descriptions simple. Focus only on what the object looks like, not + where it is or how it relates to anything else. + - For IMAGE inputs, include ONLY information supported by the image. + Do NOT invent or embellish details not visible in the image. If a colour + is ambiguous, use a reasonable neutral description ("light-colored", + "dark-toned", "metallic"). + - table.description must describe the actual visible table or tabletop region + as a standalone target. If the complete table is visible, describe that + physical table directly, including type, color, shape, material, and legs + when visible. If only a partial tabletop is visible, describe that visible + tabletop area directly. + - table.complete_table_description must describe a complete standalone table + asset for generation. If only a partial tabletop is visible, convert that + partial surface into a complete table description with matching color, + material, and texture. + - table.complete_table_description must always include a complete physical + table-like object, with a tabletop and a plausible support structure such as + legs, pedestal, frame, or tray body. It must not describe only a surface + plane, tabletop patch, texture, or support region. + - Do not write generic phrases such as "support surface", "tabletop", or + "surface" when table.name is a concrete object such as table, desk, tray, + counter, shelf, or floor. Use the concrete class in the description. + - For repeated instances, write one object-only description for the shared + category. Do not mention instance positions. + - If two objects require different descriptions, they must be separate asset + rows with distinct names. + + + + - Do not output a table id. The code will set table.id to "table". + - The table field represents the scene table or tabletop target. table.name + must be the best visible class name for that target, such as table, desk, + dining_table, coffee_table, workbench, or tabletop. + - table.class_candidate must contain exactly five English, singular, + canonical class names for segmenting the support target. The first item must + equal table.name. + - For image inputs, set table.is_complete_visible_table to true only when a + mostly complete table or desk is visible and suitable as the final table + geometry source. "Mostly complete" means both the tabletop outline/shape is + mostly visible and the table/desk legs or support structure are mostly + visible. + - Set table.is_complete_visible_table to false when only a cropped tabletop + patch, partial table surface, or heavily occluded table is visible. + - Set table.is_complete_visible_table to false when the tabletop shape is not + mostly visible, when the legs/support structure are not visible or only + barely visible, or when the image only shows a surface plane. + - If table.is_complete_visible_table is false, table.description may describe + the visible partial tabletop, but table.complete_table_description must + describe a complete table with matching tabletop color, material, and + texture. + - If table.description describes only a visible surface or tabletop patch, + table.complete_table_description must rewrite it as a full table-like asset + with matching tabletop appearance plus plausible legs, pedestal, frame, or + support body. + + + + { + "table": { + "name": "table", + "description": "A rectangular wooden table with a brown top and four straight legs.", + "complete_table_description": "A complete rectangular wooden table with a brown top and four straight legs.", + "is_complete_visible_table": false, + "class_candidate": ["table", "dining_table", "desk", "wooden_table", "furniture"] + }, + "assets": [ + { + "name": "apple", + "description": "A round apple with smooth red skin visible on the table.", + "class_candidate": ["apple", "fruit", "red_apple", "food", "produce"], + "count": 1 + }, + { + "name": "coffee_mug", + "description": "A white ceramic coffee mug with a curved handle.", + "class_candidate": ["coffee_mug", "ceramic_mug", "mug", "cup", "drinkware"], + "count": 2 + } + ] + } + + + + - The top-level object must contain only table and assets. + - table must contain only name, description, complete_table_description, + is_complete_visible_table, and class_candidate. + - Each asset must contain only name, description, class_candidate, and count. + - table.name must be a non-empty string. + - table.description must be a non-empty string. + - table.complete_table_description must be a non-empty string. + - table.is_complete_visible_table must be a boolean. + - table.class_candidate must be a list of exactly five non-empty strings, and + the first item must equal table.name. + - assets must be a list. + - Each asset.name must be a non-empty string. + - Each asset.description must be a non-empty string. + - Each asset.class_candidate must be a list of exactly five non-empty strings. + - Each asset.count must be an integer greater than or equal to 1. + - Output JSON only. Do not include markdown or explanations outside JSON. + + +text_user: | + Extract the objects and support target from this text: + $text + +image_user: | + Extract tabletop objects and the visible support target from this image. + +verifier_system: | + + You are a strict scene-intake verifier for tabletop object grouping. + + + + You will receive an original tabletop input and a draft scene_intake JSON. + Verify and correct the draft so it follows the same scene_intake schema. + + Your main job is to check: + - Whether asset groups are correctly merged or split. + - Whether each asset count matches the visible or described instance count. + - Whether each name is specific enough for later image segmentation. + - Whether table.name, table.description, table.complete_table_description, + table.is_complete_visible_table, and table.class_candidate describe the + actual table/tabletop target. + - For image inputs, independently re-check table.is_complete_visible_table + against the original image. + - Independently re-check that table.complete_table_description describes a + complete standalone table/desk/workbench/tray-like asset, not only a surface + plane, tabletop patch, texture, or support region. + + Return the corrected scene_intake JSON. Do not return comments, diffs, or + explanations. + + + + - assets is a list of object category groups, not individual instances. + - Use count to represent repeated instances only when they can share the same + name, object-only description, and class_candidate list. + - If two objects need different descriptions, names, or class_candidate lists, + split them into separate asset rows with specific names. + - Never keep two asset rows with the same name. If they are truly repeated + instances, merge them and increase count. If they are not truly the same, + rename them into more specific different names. + - Do not merge visually different subtypes under a broad name. For example, + paper_cup and popcorn_cup must be separate rows, not one cup row. + - Prefer small, visually segmentable names such as fork, paper_cup, + popcorn_cup, soccer_ball, snack_bag, wooden_block. + - Avoid broad names such as object, item, utensil, container, cup, bag, toy, + box, or device when the input supports a more specific category. + - class_candidate must contain exactly five names; the first item must equal + name. + - table.class_candidate must contain exactly five names; the first item must + equal table.name. + - Preserve the fixed table field as the table/tabletop target. + - For text inputs, table.is_complete_visible_table must be false. + - For image inputs, do not trust the draft value of + table.is_complete_visible_table. Judge it again from the attached original + image. + - For image inputs, table.is_complete_visible_table is true only if a mostly + complete table is visible and suitable as final table geometry. "Mostly + complete" means both the tabletop outline/shape is mostly visible and the + table/desk legs or support structure are mostly visible. + - If only a partial tabletop is visible, table.is_complete_visible_table must + be false and table.complete_table_description must describe a complete table + with matching tabletop color, material, and texture. + - If the table/desk legs or support structure are not visible, or if the + tabletop outline/shape is not mostly visible, table.is_complete_visible_table + must be false. + - table.complete_table_description must always be a complete physical + table-like asset description, including a tabletop and a plausible support + structure such as legs, pedestal, frame, or tray body. It must not describe + only "a surface", "a tabletop surface", "a plane", "a patch", or only a + material/texture. + - If the draft table.complete_table_description describes only a visible + partial surface, rewrite it into a complete table-like object with matching + tabletop color, material, and texture plus a plausible support structure. + - For image inputs, only count clearly visible target instances. If uncertain, + use the most conservative count supported by the image. + - For text inputs, count only objects explicitly stated or strongly implied by + the text. + + + + { + "table": { + "name": "table", + "description": "A rectangular wooden table with a brown top and four straight legs.", + "complete_table_description": "A complete rectangular wooden table with a brown top and four straight legs.", + "is_complete_visible_table": false, + "class_candidate": ["table", "dining_table", "desk", "wooden_table", "furniture"] + }, + "assets": [ + { + "name": "paper_cup", + "description": "A small white paper cup with blue printed details.", + "class_candidate": ["paper_cup", "disposable_cup", "cup", "drinkware", "container"], + "count": 1 + } + ] + } + + + + - The top-level object must contain only table and assets. + - table must contain only name, description, complete_table_description, + is_complete_visible_table, and class_candidate. + - Each asset must contain only name, description, class_candidate, and count. + - Output JSON only. Do not include markdown or explanations outside JSON. + + +verifier_text_user: | + Verify and correct this draft scene_intake JSON against the original text. + + + $text + + + + $scene_intake_json + + +verifier_image_user: | + Verify and correct this draft scene_intake JSON against the attached tabletop image. + + + $scene_intake_json + diff --git a/embodichain/gen_sim/prompt2scene/prompts/data/text_relations.yaml b/embodichain/gen_sim/prompt2scene/prompts/data/text_relations.yaml new file mode 100644 index 000000000..7a267d091 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/prompts/data/text_relations.yaml @@ -0,0 +1,110 @@ +name: text_relations +version: 1 + +system: | + + You are a strict tabletop text spatial-relation extractor. + + + + Extract only spatial constraints that are explicitly stated or strongly and + directly implied by the user's text. Do not complete the full scene layout. + Do not infer unstated object positions. Output only canonical left_of and + front_of relations. Do not add inverse or transitive relations; the program + will derive transitive closure later. + + + + - object_relations: direct object-object relations stated in text. + - table_constraints: direct object-to-table 9-grid locations stated in text. + - object_layouts: direct object support-pose constraints stated in text. + + + + - Only use these relation values: left_of, front_of. + - If the text says "A is left of B", output exactly A left_of B. + - If the text says "A is right of B", output exactly B left_of A. + - If the text says "A is in front of B", output exactly A front_of B. + - If the text says "A is behind B", output exactly B front_of A. + - Do not output right_of or behind. + - Do not output transitive relations. + - Use only asset names from the provided scene-intake assets. + + + + - Only output table_constraints when the original text explicitly states an + object-to-table region. + - Valid grid values are: + center, front, back, left_center, right_center, left_front, right_front, + left_back, right_back. + - Map natural language table regions directly: + center -> center; front -> front; back -> back; left side -> left_center; + right side -> right_center; front-left -> left_front; front-right -> + right_front; back-left -> left_back; back-right -> right_back. + - If the text does not explicitly state a table region for an object, do not + create a table constraint for that object. + - Do not infer table grid locations from object-object relations. + - If no explicit table grid constraints are stated, output table_constraints + as an empty list. + + + + - Output object_layouts only when the text explicitly describes an object's + support pose or when the object category itself strongly implies arbitrary + layout, such as a ball or round fruit. + - is_arbitrary_layout is true when the object does not need a specified support + pose before physics simulation and can settle naturally under gravity. + - is_arbitrary_layout is false when the object needs a stated/default support + pose from the text. + - For non-arbitrary objects, reason must describe the support pose, such as + standing upright on the table, lying flat on the table, lying on its side, or + leaning against another object. + + + + { + "object_relations": [ + { + "subject": "paper_cup", + "relation": "left_of", + "object": "plate", + "evidence": "The text says the paper cup is left of the plate." + } + ], + "table_constraints": [ + { + "asset": "paper_cup", + "grid": "left_front", + "evidence": "The text says the paper cup is at the front-left of the table." + } + ], + "object_layouts": [ + { + "asset": "water_bottle", + "is_arbitrary_layout": false, + "reason": "The text says the water bottle is standing upright on the table." + } + ] + } + + + + - If no relation of a type is stated, output an empty list for that field. + - Every subject, object, and asset must be one of the provided scene-intake + asset names. + - The top-level object must contain only object_relations, table_constraints, + and object_layouts. + - Do not output anchor or inferred table-region fields. + - Output JSON only. Do not include markdown or explanations outside JSON. + + +user: | + Extract explicit text spatial constraints from this prompt. + + + $asset_names + + + + $text + diff --git a/embodichain/gen_sim/prompt2scene/prompts/data/unified_scene_gen.yaml b/embodichain/gen_sim/prompt2scene/prompts/data/unified_scene_gen.yaml new file mode 100644 index 000000000..22d33af32 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/prompts/data/unified_scene_gen.yaml @@ -0,0 +1,225 @@ +name: unified_scene_gen +version: 1 + +up_down_flip_check_system: | + + You are a careful 3D tabletop geometry orientation verifier. + + + + You will receive: + - Image A: the original tabletop scene photo. + - Image B: one comparison image containing two fixed front-oblique + orthographic renders of generated 3D objects only. Each render has a + visible numeric label. + + Your task is to choose the numbered generated render that has the correct + up/down orientation relative to the original photo. + + + + - Choose selected_number=1 when candidate 1 better matches the original + photo's visible object tops and support-facing sides. + - Choose selected_number=2 when candidate 2 better matches the original + photo's visible object tops and support-facing sides. + - Do not request a yaw rotation around the vertical axis. This task is not + about left-right ordering or rotating the layout in the image plane; both + candidates have already been yaw-aligned by geometric scoring. + - The generated renders are not strict top views. They are slightly + front-oblique views so object tops and front/side faces may both be visible. + - Ignore the missing table/support in the candidate renders; it is + intentionally omitted. + - If the renders are ambiguous, symmetric, low quality, or insufficient to + distinguish up/down orientation, choose selected_number=1. + - confidence must be a number from 0 to 1. + - reason must be concise and explain the visual evidence. + + + + { + "selected_number": 1, + "confidence": 0.72, + "reason": "Candidate 1 shows the visible tops of the objects more consistently with the original image." + } + + + + - Output JSON only. Do not include markdown or explanations outside JSON. + - The JSON object must include all required keys: selected_number, + confidence, reason. + - selected_number must be exactly 1 or 2. + + +up_down_flip_check_user: | + Compare the original scene photo with the numbered generated object-only + front-oblique comparison image. + + + Choose which generated render has the correct up/down orientation. Return + exactly one JSON object with: + - selected_number: 1 or 2 + - confidence: number from 0 to 1 + - reason: short string + + +asset_metric_scale_system: | + + You estimate plausible real-world tabletop object bounding-box dimensions + from semantic descriptions. + + + + Given an object name and description, output one plausible real-world + bounding-box dimension in centimeters. + + + + - The dimensions must be in centimeters. + - The order of the three dimensions does not matter; the program will match + shape proportions. + - Estimate the full real-world object bbox, not only the visible part. + - Use common tabletop object sizes when the description is generic. + - Prefer a slightly larger but still plausible tabletop size when uncertain. + - Use confidence to express semantic certainty, not visual certainty. + - Output JSON only. Do not include markdown or text outside JSON. + + + + { + "bbox_dims_cm": [18.0, 8.0, 5.0], + "confidence": 0.72, + "reason": "Typical compact tabletop item size." + } + + +asset_metric_scale_user: | + Estimate plausible real-world bounding-box dimensions for this object. + + + $object_name + + + + $object_description + + + Return exactly one JSON object with: + - bbox_dims_cm: one slightly generous plausible size, three positive numbers in centimeters + - confidence: number from 0 to 1 + - reason: short string + +image_metric_scale_system: | + + You estimate plausible real-world tabletop object bounding-box dimensions + from a labeled scene image and object descriptions. + + + + You will receive: + - One image with each object marked by a bounding box and its object name. + - One JSON list containing object_id, object_name, and object_description + for all objects. + + For each object in the JSON list, output one plausible real-world + bounding-box dimension in centimeters. + + + + - Output one entry for every object_id in the input JSON. + - Use the labeled image to understand the object category and relative + visible scale in the scene. + - Use object_name and object_description as semantic anchors. + - The dimensions must be in centimeters. + - The order of the three dimensions does not matter. + - Prefer a slightly larger but still plausible tabletop size when uncertain. + - Use confidence to express semantic certainty. + - Output JSON only. Do not include markdown or text outside JSON. + + + + { + "object_scales": [ + { + "object_id": "interact_cup_0", + "bbox_dims_cm": [8.0, 8.0, 12.0], + "confidence": 0.78, + "reason": "Typical tabletop cup size." + } + ] + } + + +image_metric_scale_user: | + Estimate real-world dimensions for every object in the JSON below. + + + $objects_json + + + The attached image has bbox + name labels matching object_name. Return exactly + one JSON object with: + - object_scales: list of objects, one for every input object_id + - object_id: copied exactly from input + - bbox_dims_cm: one slightly generous plausible size, three positive numbers in centimeters + - confidence: number from 0 to 1 + - reason: short string + +text_metric_scale_system: | + + You estimate plausible real-world tabletop object bounding-box dimensions + from a full text scene prompt and object descriptions. + + + + You will receive: + - The user's original scene text. + - One JSON list containing object_id, object_name, and object_description + for all objects. + + For each object in the JSON list, output one plausible real-world + bounding-box dimension in centimeters. + + + + - Output one entry for every object_id in the input JSON. + - Use the full scene text to infer intended object scale and context. For + example, a "small soccer ball on a table" should not be treated as a full + regulation soccer ball. + - Use object_name and object_description as semantic anchors. + - The dimensions must be in centimeters. + - The order of the three dimensions does not matter. + - Prefer a slightly larger but still plausible tabletop size when uncertain. + - Use confidence to express semantic certainty. + - Output JSON only. Do not include markdown or text outside JSON. + + + + { + "object_scales": [ + { + "object_id": "interact_small_soccer_ball_0", + "bbox_dims_cm": [6.0, 6.0, 6.0], + "confidence": 0.74, + "reason": "The scene text describes a small tabletop soccer ball." + } + ] + } + + +text_metric_scale_user: | + Estimate real-world dimensions for every object in the JSON below. + + + $user_text + + + + $objects_json + + + Return exactly one JSON object with: + - object_scales: list of objects, one for every input object_id + - object_id: copied exactly from input + - bbox_dims_cm: one slightly generous plausible size, three positive numbers in centimeters + - confidence: number from 0 to 1 + - reason: short string diff --git a/embodichain/gen_sim/prompt2scene/utils/__init__.py b/embodichain/gen_sim/prompt2scene/utils/__init__.py new file mode 100644 index 000000000..8378c49ac --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/utils/__init__.py @@ -0,0 +1,39 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from . import log +from embodichain.gen_sim.prompt2scene.utils.io import ( + image_to_data_url, + relative_path, + write_json, +) +from embodichain.gen_sim.prompt2scene.utils.log import ( + log_api_request_start, + log_info, + log_warning, +) + +__all__ = [ + "log", + "log_api_request_start", + "log_info", + "log_warning", + "image_to_data_url", + "relative_path", + "write_json", +] diff --git a/embodichain/gen_sim/prompt2scene/utils/io.py b/embodichain/gen_sim/prompt2scene/utils/io.py new file mode 100644 index 000000000..6057d1981 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/utils/io.py @@ -0,0 +1,66 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import base64 +import json +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.utils.log import log_info + +__all__ = ["image_to_data_url", "relative_path", "write_json"] + + +def relative_path(path: str | Path, root: Path) -> str: + """Return ``path`` relative to ``root`` when it is contained by it.""" + resolved_path = Path(path) + try: + return str(resolved_path.relative_to(root)) + except ValueError: + return str(path) + + +def write_json(path: Path, payload: dict[str, Any]) -> None: + """Write a JSON payload with prompt2scene's default formatting. + + Args: + path: Output JSON file path. + payload: JSON-serializable dictionary payload. + """ + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + json.dumps(payload, indent=2, ensure_ascii=False) + "\n", + encoding="utf-8", + ) + if not path.is_file(): + raise FileNotFoundError(f"JSON output was not written: {path}") + log_info(f"Wrote JSON: {path}") + + +def image_to_data_url(image_path: Path) -> str: + """Return a base64 data URL for a local image file.""" + suffix_to_mime = { + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".png": "image/png", + ".webp": "image/webp", + ".gif": "image/gif", + } + mime_type = suffix_to_mime.get(image_path.suffix.lower(), "image/png") + encoded = base64.b64encode(image_path.read_bytes()).decode("ascii") + return f"data:{mime_type};base64,{encoded}" diff --git a/embodichain/gen_sim/prompt2scene/utils/log.py b/embodichain/gen_sim/prompt2scene/utils/log.py new file mode 100644 index 000000000..47bdfa445 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/utils/log.py @@ -0,0 +1,62 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import logging +from typing import Any + +__all__ = ["log_api_request_start", "log_info", "log_warning"] + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [EmbodiChain %(levelname)s]: %(message)s", + datefmt="%H:%M:%S", +) + +_LOGGER = logging.getLogger(__name__) +_LOGGER.setLevel(logging.INFO) + + +def _format_message(level: str, message: str) -> str: + _ = level + return f"Prompt2Scene: {message}" + + +def log_info(message: str) -> None: + """Log an info message using the EmbodiChain log prefix.""" + _LOGGER.info(_format_message("INFO", message)) + + +def log_warning(message: str) -> None: + """Log a warning message using the EmbodiChain log prefix.""" + _LOGGER.warning(_format_message("WARNING", message)) + + +def log_api_request_start( + *, + step: str, + request: str, + attempt: int | None = None, + **details: Any, +) -> None: + """Log the start of an API request with a stable key order.""" + fields = [f"step={step}", f"request={request}"] + if attempt is not None: + fields.append(f"attempt={attempt}") + for key, value in details.items(): + fields.append(f"{key}={value}") + log_info("api request start " + " ".join(fields)) diff --git a/embodichain/gen_sim/prompt2scene/workflows/__init__.py b/embodichain/gen_sim/prompt2scene/workflows/__init__.py new file mode 100644 index 000000000..393b0022b --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/__init__.py @@ -0,0 +1,41 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( + DEBUG_DIRNAME, + IMAGE_SEGMENTS_STEP, + IMAGE_SPATIAL_RELATIONS_STEP, + RAW_MODEL_OUTPUT_FILENAME, + SCENE_INTAKE_STEP, + STEP_RESULT_FILENAME, + TEXT_RELATIONS_STEP, + UNIFIED_SCENE_STEP, + WorkflowArtifactWriter, +) + +__all__ = [ + "DEBUG_DIRNAME", + "IMAGE_SEGMENTS_STEP", + "IMAGE_SPATIAL_RELATIONS_STEP", + "RAW_MODEL_OUTPUT_FILENAME", + "SCENE_INTAKE_STEP", + "STEP_RESULT_FILENAME", + "TEXT_RELATIONS_STEP", + "UNIFIED_SCENE_STEP", + "WorkflowArtifactWriter", +] diff --git a/embodichain/gen_sim/prompt2scene/workflows/artifact_writer.py b/embodichain/gen_sim/prompt2scene/workflows/artifact_writer.py new file mode 100644 index 000000000..6587ccbbc --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/artifact_writer.py @@ -0,0 +1,271 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +import re +from typing import Any + +from embodichain.gen_sim.prompt2scene.utils.io import write_json + +__all__ = [ + "DEBUG_DIRNAME", + "IMAGE_SEGMENTS_STEP", + "IMAGE_SPATIAL_RELATIONS_STEP", + "RAW_MODEL_OUTPUT_FILENAME", + "SCENE_INTAKE_STEP", + "STEP_RESULT_FILENAME", + "TEXT_RELATIONS_STEP", + "UNIFIED_SCENE_GEN_STEP", + "UNIFIED_SCENE_STEP", + "WorkflowArtifactWriter", + "debug_dir_path", + "debug_round_dir_path", + "next_debug_round_dir_path", + "next_debug_round_name", + "step_dir_path", + "step_result_path", + "write_debug_json", + "write_debug_round_json", + "write_next_raw_model_output", + "write_raw_model_output", + "write_step_result", +] + +STEP_RESULT_FILENAME = "result.json" +DEBUG_DIRNAME = "debug" +RAW_MODEL_OUTPUT_FILENAME = "raw_model_output.json" + +SCENE_INTAKE_STEP = "scene_intake" +IMAGE_SEGMENTS_STEP = "image_segments" +IMAGE_SPATIAL_RELATIONS_STEP = "image_spatial_relations" +TEXT_RELATIONS_STEP = "text_relations" +UNIFIED_SCENE_STEP = "unified_scene" +UNIFIED_SCENE_GEN_STEP = "unified_scene_gen" + +DEBUG_ROUND_PATTERN = re.compile(r"^round_(\d+)(?:_|$)") + + +def step_dir_path(output_root: Path, step_name: str) -> Path: + """Return the directory path for a pipeline step.""" + return output_root / step_name + + +def step_result_path(output_root: Path, step_name: str) -> Path: + """Return the final result JSON path for a pipeline step.""" + return step_dir_path(output_root, step_name) / STEP_RESULT_FILENAME + + +def debug_dir_path(output_root: Path, step_name: str) -> Path: + """Return the debug directory path for a pipeline step.""" + return step_dir_path(output_root, step_name) / DEBUG_DIRNAME + + +def debug_round_dir_path( + output_root: Path, + step_name: str, + round_name: str, +) -> Path: + """Return a debug subdirectory path for one model/tool round.""" + return debug_dir_path(output_root, step_name) / round_name + + +def next_debug_round_name( + output_root: Path, + step_name: str, + label: str | None = None, +) -> str: + """Return the next step-local debug round name.""" + debug_dir = debug_dir_path(output_root, step_name) + max_index = 0 + if debug_dir.is_dir(): + for path in debug_dir.iterdir(): + if not path.is_dir(): + continue + match = DEBUG_ROUND_PATTERN.match(path.name) + if match is not None: + max_index = max(max_index, int(match.group(1))) + round_name = f"round_{max_index + 1:03d}" + if label: + round_name = f"{round_name}_{_path_token(label)}" + return round_name + + +def next_debug_round_dir_path( + output_root: Path, + step_name: str, + label: str | None = None, +) -> Path: + """Return the next step-local debug round directory path.""" + return debug_round_dir_path( + output_root, + step_name, + next_debug_round_name(output_root, step_name, label), + ) + + +def write_step_result( + output_root: Path, + step_name: str, + payload: dict[str, Any], +) -> Path: + """Write a step's final result JSON and return its path.""" + path = step_result_path(output_root, step_name) + write_json(path, payload) + return path + + +def write_debug_json( + output_root: Path, + step_name: str, + round_name: str, + filename: str, + payload: dict[str, Any], +) -> Path: + """Write a debug JSON file under one step debug round.""" + path = debug_round_dir_path(output_root, step_name, round_name) / filename + write_json(path, payload) + return path + + +def write_debug_round_json( + debug_round_dir: Path, + filename: str, + payload: dict[str, Any], +) -> Path: + """Write a debug JSON file under an already selected debug round directory.""" + path = debug_round_dir / filename + write_json(path, payload) + return path + + +def write_raw_model_output( + output_root: Path, + step_name: str, + round_name: str, + payload: dict[str, Any], +) -> Path: + """Write one raw structured model output under a step debug round.""" + return write_debug_json( + output_root, + step_name, + round_name, + RAW_MODEL_OUTPUT_FILENAME, + payload, + ) + + +def write_next_raw_model_output( + output_root: Path, + step_name: str, + payload: dict[str, Any], + label: str | None = None, +) -> Path: + """Write raw model output under the next step-local debug round.""" + round_name = next_debug_round_name(output_root, step_name, label) + return write_raw_model_output(output_root, step_name, round_name, payload) + + +class WorkflowArtifactWriter: + """Write workflow artifacts under a fixed step directory.""" + + def __init__(self, output_root: Path, step_name: str) -> None: + self._output_root = output_root + self._step_name = step_name + + @property + def output_root(self) -> Path: + return self._output_root + + @property + def step_name(self) -> str: + return self._step_name + + @property + def step_dir(self) -> Path: + return step_dir_path(self._output_root, self._step_name) + + @property + def debug_dir(self) -> Path: + return debug_dir_path(self._output_root, self._step_name) + + @property + def result_path(self) -> Path: + return step_result_path(self._output_root, self._step_name) + + def next_debug_round_name(self, label: str | None = None) -> str: + """Return the next debug round name for this step.""" + return next_debug_round_name(self._output_root, self._step_name, label) + + def next_debug_round_dir(self, label: str | None = None) -> Path: + """Return the next debug round directory for this step.""" + return next_debug_round_dir_path(self._output_root, self._step_name, label) + + def debug_round_dir(self, round_name: str) -> Path: + """Return one debug round directory under this step.""" + return debug_round_dir_path(self._output_root, self._step_name, round_name) + + def write_step_result(self, payload: dict[str, Any]) -> Path: + """Write the step's final result JSON.""" + return write_step_result(self._output_root, self._step_name, payload) + + def write_debug_round_json( + self, + *, + round_name: str, + filename: str, + payload: dict[str, Any], + ) -> Path: + """Write a JSON artifact inside one named debug round.""" + return write_debug_round_json( + self.debug_round_dir(round_name), + filename=filename, + payload=payload, + ) + + def write_raw_model_output( + self, + *, + round_name: str, + payload: dict[str, Any], + ) -> Path: + """Write a raw model output into one named debug round.""" + return write_raw_model_output( + self._output_root, + self._step_name, + round_name, + payload, + ) + + def write_next_raw_model_output( + self, + *, + payload: dict[str, Any], + label: str | None = None, + ) -> Path: + """Write a raw model output into the next available debug round.""" + return write_next_raw_model_output( + self._output_root, + self._step_name, + payload, + label=label, + ) + + +def _path_token(value: str) -> str: + token = "".join(character if character.isalnum() else "_" for character in value) + return token.strip("_")[:80] or "round" diff --git a/embodichain/gen_sim/prompt2scene/workflows/attempt_state.py b/embodichain/gen_sim/prompt2scene/workflows/attempt_state.py new file mode 100644 index 000000000..15407e78e --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/attempt_state.py @@ -0,0 +1,30 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import TypedDict + +__all__ = ["AttemptState"] + + +class AttemptState(TypedDict): + """Common retry/error fields for one model-call stage.""" + + attempt_count: int + max_attempts: int + last_error: str | None + errors: list[str] diff --git a/embodichain/gen_sim/prompt2scene/workflows/image_relations/__init__.py b/embodichain/gen_sim/prompt2scene/workflows/image_relations/__init__.py new file mode 100644 index 000000000..ab49ab724 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/image_relations/__init__.py @@ -0,0 +1,24 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.workflows.image_relations.graph import ( + build_image_relations_graph, + run_image_relations, +) + +__all__ = ["build_image_relations_graph", "run_image_relations"] diff --git a/embodichain/gen_sim/prompt2scene/workflows/image_relations/graph.py b/embodichain/gen_sim/prompt2scene/workflows/image_relations/graph.py new file mode 100644 index 000000000..ff67f3a03 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/image_relations/graph.py @@ -0,0 +1,189 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from langgraph.graph import END, StateGraph + +from embodichain.gen_sim.prompt2scene.llms import ( + OpenAICompatibleLLMCfg, + build_chat_model, +) +from embodichain.gen_sim.prompt2scene.utils import log +from embodichain.gen_sim.prompt2scene.workflows.stage_errors import ( + format_result_missing_error, +) +from embodichain.gen_sim.prompt2scene.workflows.image_relations.schema import ( + ImageRelationSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.image_relations.nodes import ( + call_vlm_filter_initial_segments_node, + call_vlm_spatial_layout_node, + normalize_asset_segments_node, + prepare_segmentation_input_node, + retry_missing_by_candidates_node, + segment_table_node, + segment_by_name_node, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SceneIntakeSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.request import Prompt2SceneInput +from embodichain.gen_sim.prompt2scene.workflows.image_relations.state import ( + ImageRelationsState, +) + +__all__ = ["build_image_relations_graph", "run_image_relations"] + + +def route_after_filter_extra_instances(state: ImageRelationsState) -> str: + """Route to retry or continue after VLM extra-instance filtering.""" + if state["last_error"] is None: + return "continue" + if state["attempt_count"] < state["max_attempts"]: + return "retry" + return "continue" + + +def route_after_spatial_layout(state: ImageRelationsState) -> str: + """Route to retry or finish after spatial-layout extraction.""" + if state["last_error"] is None: + return "end" + if state["attempt_count"] < state["max_attempts"]: + return "retry" + return "end" + + +def build_image_relations_graph(llm: Any) -> Any: + """Build the fixed LangGraph image asset segmentation workflow.""" + graph = StateGraph(ImageRelationsState) + graph.add_node("prepare_segmentation_input", prepare_segmentation_input_node) + graph.add_node("segment_by_name", segment_by_name_node) + graph.add_node( + "call_vlm_filter_initial_segments", + lambda state: call_vlm_filter_initial_segments_node(state, llm=llm), + ) + graph.add_node( + "retry_missing_by_candidates", + lambda state: retry_missing_by_candidates_node(state, llm=llm), + ) + graph.add_node("normalize_asset_segments", normalize_asset_segments_node) + graph.add_node( + "segment_table", + lambda state: segment_table_node(state, llm=llm), + ) + graph.add_node( + "call_vlm_spatial_layout", + lambda state: call_vlm_spatial_layout_node(state, llm=llm), + ) + + graph.set_entry_point("prepare_segmentation_input") + graph.add_edge("prepare_segmentation_input", "segment_by_name") + graph.add_edge("segment_by_name", "call_vlm_filter_initial_segments") + graph.add_conditional_edges( + "call_vlm_filter_initial_segments", + route_after_filter_extra_instances, + { + "retry": "call_vlm_filter_initial_segments", + "continue": "retry_missing_by_candidates", + }, + ) + graph.add_edge("retry_missing_by_candidates", "normalize_asset_segments") + graph.add_edge("normalize_asset_segments", "segment_table") + graph.add_edge("segment_table", "call_vlm_spatial_layout") + graph.add_conditional_edges( + "call_vlm_spatial_layout", + route_after_spatial_layout, + { + "retry": "call_vlm_spatial_layout", + "end": END, + }, + ) + return graph.compile() + + +def run_image_relations( + request: Prompt2SceneInput, + *, + scene_intake: SceneIntakeSpec, + llm_cfg: OpenAICompatibleLLMCfg, + output_root: Path, +) -> ImageRelationSpec: + """Run image asset segmentation alignment for one prompt2scene request.""" + llm = build_chat_model(llm_cfg) + graph = build_image_relations_graph(llm) + result = graph.invoke( + { + "request": request, + "scene_intake": scene_intake, + "output_root": output_root, + "segment_groups": [], + "raw_model_output": None, + "image_relations": None, + "attempt_count": 0, + "max_attempts": llm_cfg.max_attempts, + "last_error": None, + "errors": [], + } + ) + + image_relations = result.get("image_relations") + if ( + image_relations is not None + and image_relations.status == "ok" + and image_relations.anchor is not None + ): + return image_relations + if image_relations is not None and image_relations.status == "ok": + error = format_result_missing_error( + "Image relations", + "spatial layout", + attempt_count=result.get("attempt_count", 0), + last_error=result.get("last_error"), + errors=result.get("errors", []), + ) + log.log_warning(error) + raise RuntimeError(error) + if image_relations is not None: + failed_groups = [ + group.to_manifest() + for group in image_relations.groups + if group.status != "ok" + ] + if ( + image_relations.table_group is not None + and image_relations.table_group.status != "ok" + ): + failed_groups.append(image_relations.table_group.to_manifest()) + error = ( + "Image relations failed to align all image segments. " + f"Failed groups: {failed_groups}" + ) + log.log_warning(error) + raise RuntimeError(error) + + error = format_result_missing_error( + "Image relations", + "ImageRelationSpec", + attempt_count=result.get("attempt_count", 0), + last_error=result.get("last_error"), + errors=result.get("errors", []), + ) + log.log_warning(error) + raise RuntimeError(error) diff --git a/embodichain/gen_sim/prompt2scene/workflows/image_relations/nodes.py b/embodichain/gen_sim/prompt2scene/workflows/image_relations/nodes.py new file mode 100644 index 000000000..ab8b69522 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/image_relations/nodes.py @@ -0,0 +1,511 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client import ( + decode_rle_mask, + draw_numbered_masks, +) +from embodichain.gen_sim.prompt2scene.workflows.image_relations.schema import ( + ImageAssetSegment, + ImageRelationGroup, + ImageRelationSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.request import InputKind +from embodichain.gen_sim.prompt2scene.workflows.image_relations.schema import ( + FILTER_EXTRA_INSTANCES_JSON_SCHEMA, + SPATIAL_LAYOUT_JSON_SCHEMA, +) +from embodichain.gen_sim.prompt2scene.utils import ( + log_api_request_start, + log, +) +from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( + IMAGE_SEGMENTS_STEP, + IMAGE_SPATIAL_RELATIONS_STEP, + WorkflowArtifactWriter, +) +from embodichain.gen_sim.prompt2scene.workflows.image_relations.utils import ( + append_unique, + apply_spatial_layout_output, + asset_bbox_label, + draw_labeled_bboxes, + expand_asset_ids, + filter_group_segments_with_vlm, + filter_segments_with_vlm, + merge_non_overlapping_segments, + prompt_text, + path_token, + require_image_path, + segment_prompt, + segments_from_response, +) +from embodichain.gen_sim.prompt2scene.workflows.image_relations.prompts import ( + build_filter_extra_instances_messages, + build_spatial_layout_messages, +) +from embodichain.gen_sim.prompt2scene.workflows.image_relations.state import ( + ImageRelationsState, +) +from embodichain.gen_sim.prompt2scene.workflows.llm_output import ( + call_structured_json_model_step, + is_model_output_error, +) +from embodichain.gen_sim.prompt2scene.workflows.stage_errors import ( + format_attempt_error, +) + +__all__ = [ + "call_vlm_filter_extra_instances_node", + "call_vlm_filter_initial_segments_node", + "call_vlm_spatial_layout_node", + "normalize_asset_segments_node", + "prepare_segmentation_input_node", + "retry_missing_by_candidates_node", + "segment_table_node", + "segment_by_name_node", +] + +def prepare_segmentation_input_node(state: ImageRelationsState) -> dict[str, object]: + """Prepare scene-intake asset groups for class-level segmentation.""" + request = state["request"] + if request.input_kind != InputKind.IMAGE or request.image_path is None: + raise ValueError("Image relations requires an image input.") + + segment_groups = [] + for asset in state["scene_intake"].assets: + group = { + "name": asset.name, + "description": asset.description, + "asset_ids": expand_asset_ids(asset.id, asset.count), + "class_candidate": list(asset.class_candidate), + "segments": [], + "tried_prompts": [], + "debug_images": [], + "status": "pending", + "error": None, + "expected_count": asset.count, + } + segment_groups.append(group) + return {"segment_groups": segment_groups} + + +def segment_by_name_node(state: ImageRelationsState) -> dict[str, object]: + """Run SAM3 once per object name.""" + image_path = require_image_path(state) + segment_groups = [] + for group in state["segment_groups"]: + prompt = prompt_text(group["name"]) + response = segment_prompt(image_path=image_path, prompt=prompt) + group = dict(group) + group["tried_prompts"] = append_unique(group["tried_prompts"], prompt) + group["segments"] = segments_from_response( + group=group, + response=response, + source_prompt=prompt, + ) + segment_groups.append(group) + return {"segment_groups": segment_groups} + + +def call_vlm_filter_extra_instances_node( + state: ImageRelationsState, + *, + llm: Any, +) -> dict[str, object]: + """Compatibility wrapper for the initial VLM segment filter.""" + return call_vlm_filter_initial_segments_node(state, llm=llm) + + +def call_vlm_filter_initial_segments_node( + state: ImageRelationsState, + *, + llm: Any, +) -> dict[str, object]: + """Ask VLM to remove wrong masks from initial name-based SAM3 output.""" + return filter_segments_with_vlm(state=state, llm=llm, stage="initial") +def retry_missing_by_candidates_node( + state: ImageRelationsState, + *, + llm: Any, +) -> dict[str, object]: + """Use remaining class candidates to add missing segment instances.""" + image_path = require_image_path(state) + artifact_writer = WorkflowArtifactWriter(state["output_root"], IMAGE_SEGMENTS_STEP) + segment_groups = [] + for group in state["segment_groups"]: + group = dict(group) + segments = group["segments"] + expected_count = group["expected_count"] + for candidate_name in group["class_candidate"][1:]: + if len(segments) >= expected_count: + break + prompt = prompt_text(candidate_name) + if prompt in group["tried_prompts"]: + continue + response = segment_prompt(image_path=image_path, prompt=prompt) + group["tried_prompts"] = append_unique(group["tried_prompts"], prompt) + new_segments = segments_from_response( + group=group, + response=response, + source_prompt=prompt, + ) + new_segments = filter_group_segments_with_vlm( + llm=llm, + image_path=image_path, + artifact_writer=artifact_writer, + group=group, + segments=new_segments, + stage=f"fallback_{path_token(prompt)}", + ) + segments = merge_non_overlapping_segments( + existing=segments, + incoming=new_segments, + limit=expected_count, + ) + if len(segments) < expected_count: + description_prompt = str(group.get("description") or "").strip() + if description_prompt and description_prompt not in group["tried_prompts"]: + response = segment_prompt( + image_path=image_path, + prompt=description_prompt, + ) + group["tried_prompts"] = append_unique( + group["tried_prompts"], + description_prompt, + ) + new_segments = segments_from_response( + group=group, + response=response, + source_prompt=description_prompt, + ) + new_segments = filter_group_segments_with_vlm( + llm=llm, + image_path=image_path, + artifact_writer=artifact_writer, + group=group, + segments=new_segments, + stage="fallback_description", + ) + segments = merge_non_overlapping_segments( + existing=segments, + incoming=new_segments, + limit=expected_count, + ) + group["segments"] = segments + segment_groups.append(group) + return {"segment_groups": segment_groups} + + +def normalize_asset_segments_node(state: ImageRelationsState) -> dict[str, object]: + """Assign final segments to scene-intake asset IDs.""" + image_path = require_image_path(state) + asset_segments: list[ImageAssetSegment] = [] + relation_groups: list[ImageRelationGroup] = [] + status = "ok" + + for group in state["segment_groups"]: + expected_count = group["expected_count"] + segments = group["segments"] + group_status = "ok" + error = None + if len(segments) < expected_count: + group_status = "failed" + error = "missing_segments" + status = "failed" + elif len(segments) > expected_count: + group_status = "failed" + error = "extra_segments" + status = "failed" + + relation_groups.append( + ImageRelationGroup( + name=group["name"], + expected_count=expected_count, + detected_count=len(segments), + status=group_status, + tried_prompts=list(group["tried_prompts"]), + asset_ids=list(group["asset_ids"]), + debug_images=list(group["debug_images"]), + error=error, + ) + ) + + if group_status != "ok": + continue + for asset_id, segment in zip(group["asset_ids"], segments): + asset_segments.append( + ImageAssetSegment( + asset_id=asset_id, + name=group["name"], + segment_id=segment["segment_id"], + bbox_xyxy=list(segment["bbox_xyxy"]), + score=float(segment["score"]), + source_prompt=segment["source_prompt"], + mask_rle=segment.get("mask_rle"), + ) + ) + + bbox_name_image_path = None + if status == "ok": + artifact_writer = WorkflowArtifactWriter( + state["output_root"], + IMAGE_SEGMENTS_STEP, + ) + bbox_name_image_path = str( + draw_labeled_bboxes( + image_path=image_path, + boxes=[ + { + "bbox_xyxy": segment.bbox_xyxy, + "label": asset_bbox_label(segment.asset_id), + } + for segment in asset_segments + ], + output_path=artifact_writer.step_dir / "asset_segments_bbox_name.png", + ) + ) + + image_relations = ImageRelationSpec( + status=status, + image_path=str(image_path), + asset_segments=asset_segments, + groups=relation_groups, + bbox_name_image_path=bbox_name_image_path, + ) + WorkflowArtifactWriter( + state["output_root"], + IMAGE_SEGMENTS_STEP, + ).write_step_result(image_relations.to_segmentation_manifest()) + return {"image_relations": image_relations} + + +def segment_table_node( + state: ImageRelationsState, + *, + llm: Any, +) -> dict[str, object]: + """Segment the table/support target after object segmentation is complete.""" + image_relations = state["image_relations"] + if image_relations is None or image_relations.status != "ok": + return {} + + image_path = require_image_path(state) + table = state["scene_intake"].table + artifact_writer = WorkflowArtifactWriter(state["output_root"], IMAGE_SEGMENTS_STEP) + group = { + "name": table.name, + "description": table.description, + "asset_ids": [table.id], + "class_candidate": list(table.class_candidate), + "segments": [], + "tried_prompts": [], + "debug_images": [], + "status": "pending", + "error": None, + "expected_count": 1, + } + segments: list[dict[str, Any]] = [] + + for prompt in _table_segmentation_prompts(group): + if len(segments) >= 1: + break + response = segment_prompt(image_path=image_path, prompt=prompt) + group["tried_prompts"] = append_unique(group["tried_prompts"], prompt) + new_segments = segments_from_response( + group=group, + response=response, + source_prompt=prompt, + ) + _write_table_candidate_debug_image( + image_path=image_path, + artifact_writer=artifact_writer, + group=group, + segments=new_segments, + stage=f"table_{path_token(prompt)}", + ) + selected_segment = _select_largest_table_segment(new_segments) + if selected_segment is not None: + segments = [selected_segment] + + group_status = "ok" if len(segments) == 1 else "failed" + error = None if group_status == "ok" else "missing_table_segment" + table_group = ImageRelationGroup( + name=group["name"], + expected_count=1, + detected_count=len(segments), + status=group_status, + tried_prompts=list(group["tried_prompts"]), + asset_ids=[table.id], + debug_images=list(group["debug_images"]), + error=error, + ) + table_segment = None + if group_status == "ok": + segment = segments[0] + table_segment = ImageAssetSegment( + asset_id=table.id, + name=table.name, + segment_id=segment["segment_id"], + bbox_xyxy=list(segment["bbox_xyxy"]), + score=float(segment["score"]), + source_prompt=segment["source_prompt"], + mask_rle=segment.get("mask_rle"), + ) + + updated_image_relations = ImageRelationSpec( + status="ok" if group_status == "ok" else "failed", + image_path=image_relations.image_path, + asset_segments=image_relations.asset_segments, + groups=image_relations.groups, + table_segment=table_segment, + table_group=table_group, + bbox_name_image_path=image_relations.bbox_name_image_path, + anchor=image_relations.anchor, + x_order=image_relations.x_order, + y_order=image_relations.y_order, + asset_layouts=image_relations.asset_layouts, + ) + artifact_writer.write_step_result(updated_image_relations.to_segmentation_manifest()) + return {"image_relations": updated_image_relations} + + +def _table_segmentation_prompts(group: dict[str, Any]) -> list[str]: + """Return table/support segmentation prompts in object-style fallback order.""" + prompts = [prompt_text(group["name"])] + for candidate_name in group["class_candidate"][1:]: + prompts.append(prompt_text(candidate_name)) + description_prompt = str(group.get("description") or "").strip() + if description_prompt: + prompts.append(description_prompt) + + unique_prompts: list[str] = [] + for prompt in prompts: + if prompt and prompt not in unique_prompts: + unique_prompts.append(prompt) + return unique_prompts + + +def _write_table_candidate_debug_image( + *, + image_path: Path, + artifact_writer: WorkflowArtifactWriter, + group: dict[str, Any], + segments: list[dict[str, Any]], + stage: str, +) -> None: + """Write table/support candidate mask debug image without VLM filtering.""" + if not segments: + return + round_name = artifact_writer.next_debug_round_name(label=f"{stage}_{group['name']}") + round_dir = artifact_writer.debug_round_dir(round_name) + debug_image_path = draw_numbered_masks( + image_path=image_path, + segments=segments, + output_path=round_dir / "mask.png", + ) + group["debug_images"] = append_unique( + group["debug_images"], + str(debug_image_path), + ) + + +def _select_largest_table_segment( + segments: list[dict[str, Any]], +) -> dict[str, Any] | None: + """Select the largest SAM3 table/support candidate without VLM filtering.""" + if not segments: + return None + return max(segments, key=_segment_area) + + +def _segment_area(segment: dict[str, Any]) -> float: + mask_rle = segment.get("mask_rle") + if mask_rle is not None: + try: + mask = decode_rle_mask(mask_rle).convert("L") + histogram = mask.histogram() + return float(sum(count for value, count in enumerate(histogram) if value)) + except Exception: + pass + x1, y1, x2, y2 = segment["bbox_xyxy"] + return max(0.0, float(x2) - float(x1)) * max(0.0, float(y2) - float(y1)) + + +def call_vlm_spatial_layout_node( + state: ImageRelationsState, + *, + llm: Any, +) -> dict[str, object]: + """Ask VLM for object ordering, anchor grid, and per-object layout states.""" + image_relations = state["image_relations"] + if image_relations is None or image_relations.status != "ok": + return {} + if image_relations.bbox_name_image_path is None: + raise ValueError("Image spatial layout requires bbox_name_image_path.") + + attempt_count = state["attempt_count"] + 1 + asset_ids = [segment.asset_id for segment in image_relations.asset_segments] + artifact_writer = WorkflowArtifactWriter( + state["output_root"], + IMAGE_SPATIAL_RELATIONS_STEP, + ) + messages = build_spatial_layout_messages( + bbox_name_image_path=Path(image_relations.bbox_name_image_path), + asset_ids=asset_ids, + ) + + try: + log_api_request_start( + step=IMAGE_SPATIAL_RELATIONS_STEP, + request="spatial_layout", + attempt=attempt_count, + ) + raw_model_output = call_structured_json_model_step( + llm=llm, + schema=SPATIAL_LAYOUT_JSON_SCHEMA, + messages=messages, + context="Image spatial layout", + step_name=IMAGE_SPATIAL_RELATIONS_STEP, + output_root=None, + attempt_count=attempt_count, + raw_output_label="spatial_layout", + artifact_writer=artifact_writer, + ) + updated_image_relations = apply_spatial_layout_output( + image_relations=image_relations, + raw_model_output=raw_model_output, + ) + artifact_writer.write_step_result(updated_image_relations.to_spatial_manifest()) + except Exception as exc: + if is_model_output_error(exc) or isinstance(exc, ValueError): + error = format_attempt_error("Image relations spatial layout", attempt_count, exc) + log.log_warning(error) + return { + "attempt_count": attempt_count, + "last_error": error, + "errors": state["errors"] + [error], + } + raise + return { + "attempt_count": attempt_count, + "image_relations": updated_image_relations, + "last_error": None, + } diff --git a/embodichain/gen_sim/prompt2scene/workflows/image_relations/prompts.py b/embodichain/gen_sim/prompt2scene/workflows/image_relations/prompts.py new file mode 100644 index 000000000..f974f442e --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/image_relations/prompts.py @@ -0,0 +1,113 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.prompts import render_prompt +from embodichain.gen_sim.prompt2scene.utils.io import image_to_data_url + +__all__ = [ + "build_filter_extra_instances_messages", + "build_spatial_layout_messages", +] + +IMAGE_RELATIONS_PROMPT_NAME = "image_relations.yaml" + + +def build_filter_extra_instances_messages( + *, + debug_image_path: Path, + name: str, + description: str, + expected_count: int, + class_candidate: list[str], +) -> list[dict[str, Any]]: + """Build LangChain-compatible messages for VLM extra-mask filtering.""" + return [ + { + "role": "system", + "content": render_prompt( + IMAGE_RELATIONS_PROMPT_NAME, + prompt_key="filter_extra_instances_system", + ), + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": render_prompt( + IMAGE_RELATIONS_PROMPT_NAME, + { + "name": name.replace("_", " "), + "description": description, + "expected_count": str(expected_count), + "class_candidate": ", ".join( + candidate.replace("_", " ") + for candidate in class_candidate + ), + }, + prompt_key="filter_extra_instances_user", + ), + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(debug_image_path)}, + }, + ], + }, + ] + + +def build_spatial_layout_messages( + *, + bbox_name_image_path: Path, + asset_ids: list[str], +) -> list[dict[str, Any]]: + """Build messages for VLM spatial ordering and object-state extraction.""" + return [ + { + "role": "system", + "content": render_prompt( + IMAGE_RELATIONS_PROMPT_NAME, + prompt_key="spatial_layout_system", + ), + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": render_prompt( + IMAGE_RELATIONS_PROMPT_NAME, + { + "asset_ids": "\n".join( + f"- {asset_id}" for asset_id in asset_ids + ), + }, + prompt_key="spatial_layout_user", + ), + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(bbox_name_image_path)}, + }, + ], + }, + ] diff --git a/embodichain/gen_sim/prompt2scene/workflows/image_relations/schema.py b/embodichain/gen_sim/prompt2scene/workflows/image_relations/schema.py new file mode 100644 index 000000000..500f7c702 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/image_relations/schema.py @@ -0,0 +1,250 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.spatial import GRID_VALUE_LIST + +__all__ = [ + "FILTER_EXTRA_INSTANCES_JSON_SCHEMA", + "ImageAnchor", + "ImageAssetLayout", + "ImageAssetSegment", + "ImageRelationGroup", + "ImageRelationSpec", + "SPATIAL_LAYOUT_JSON_SCHEMA", +] + +FILTER_EXTRA_INSTANCES_JSON_SCHEMA: dict[str, Any] = { + "title": "FilterExtraImageInstancesOutput", + "type": "object", + "additionalProperties": False, + "properties": { + "extra_instance_numbers": { + "type": "array", + "description": "1-based mask numbers that should be removed.", + "items": {"type": "integer", "minimum": 1}, + }, + "reason": { + "type": "string", + "description": "Brief reason for the removal decision.", + }, + }, + "required": ["extra_instance_numbers", "reason"], +} + +SPATIAL_LAYOUT_JSON_SCHEMA: dict[str, Any] = { + "title": "ImageSpatialLayoutOutput", + "type": "object", + "additionalProperties": False, + "properties": { + "anchor": { + "type": "object", + "additionalProperties": False, + "properties": { + "asset_id": {"type": "string", "minLength": 1}, + "grid": { + "type": "string", + "enum": GRID_VALUE_LIST, + }, + "reason": {"type": "string"}, + }, + "required": ["asset_id", "grid", "reason"], + }, + "x_order": { + "type": "array", + "description": "Asset-id groups ordered from left to right.", + "items": { + "type": "array", + "items": {"type": "string", "minLength": 1}, + "minItems": 1, + }, + "minItems": 1, + }, + "y_order": { + "type": "array", + "description": "Asset-id groups ordered from front to back.", + "items": { + "type": "array", + "items": {"type": "string", "minLength": 1}, + "minItems": 1, + }, + "minItems": 1, + }, + "asset_states": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": True, + "properties": { + "asset_id": {"type": "string", "minLength": 1}, + "is_arbitrary_layout": {"type": "boolean"}, + "reason": {"type": "string", "minLength": 1}, + }, + "required": [ + "asset_id", + "is_arbitrary_layout", + "reason", + ], + }, + }, + }, + "required": ["anchor", "x_order", "y_order", "asset_states"], +} + + +@dataclass(frozen=True) +class ImageAssetSegment: + """Image segmentation result aligned to one scene-intake asset.""" + + asset_id: str + name: str + segment_id: str + bbox_xyxy: list[float] + score: float + source_prompt: str + mask_rle: dict[str, Any] | None = None + + def to_manifest(self) -> dict[str, Any]: + """Convert the segment to JSON-safe data.""" + return { + "asset_id": self.asset_id, + "name": self.name, + "segment_id": self.segment_id, + "bbox_xyxy": list(self.bbox_xyxy), + "score": self.score, + "source_prompt": self.source_prompt, + "mask_rle": self.mask_rle, + } + + +@dataclass(frozen=True) +class ImageRelationGroup: + """Segmentation alignment status for assets sharing one object name.""" + + name: str + expected_count: int + detected_count: int + status: str + tried_prompts: list[str] = field(default_factory=list) + asset_ids: list[str] = field(default_factory=list) + debug_images: list[str] = field(default_factory=list) + error: str | None = None + + def to_manifest(self) -> dict[str, Any]: + """Convert the group to JSON-safe data.""" + return { + "name": self.name, + "expected_count": self.expected_count, + "detected_count": self.detected_count, + "status": self.status, + "tried_prompts": list(self.tried_prompts), + "asset_ids": list(self.asset_ids), + "debug_images": list(self.debug_images), + "error": self.error, + } + + +@dataclass(frozen=True) +class ImageAnchor: + """Anchor object used to place relative ordering onto the table grid.""" + + asset_id: str + grid: str + reason: str = "" + + def to_manifest(self) -> dict[str, Any]: + """Convert the anchor to JSON-safe data.""" + return { + "asset_id": self.asset_id, + "grid": self.grid, + "reason": self.reason, + } + + +@dataclass(frozen=True) +class ImageAssetLayout: + """Support state for one image asset instance.""" + + asset_id: str + is_arbitrary_layout: bool + reason: str = "" + + def to_manifest(self) -> dict[str, Any]: + """Convert the layout to JSON-safe data.""" + return { + "asset_id": self.asset_id, + "is_arbitrary_layout": self.is_arbitrary_layout, + "reason": self.reason, + } + + +@dataclass(frozen=True) +class ImageRelationSpec: + """Image asset segmentation alignment and spatial relations.""" + + status: str + image_path: str + asset_segments: list[ImageAssetSegment] + groups: list[ImageRelationGroup] + table_segment: ImageAssetSegment | None = None + table_group: ImageRelationGroup | None = None + bbox_name_image_path: str | None = None + anchor: ImageAnchor | None = None + x_order: list[list[str]] = field(default_factory=list) + y_order: list[list[str]] = field(default_factory=list) + asset_layouts: list[ImageAssetLayout] = field(default_factory=list) + + def to_manifest(self) -> dict[str, Any]: + """Convert the image relation spec to JSON-safe data.""" + manifest = self.to_segmentation_manifest() + manifest.update(self.to_spatial_manifest()) + return manifest + + def to_segmentation_manifest(self) -> dict[str, Any]: + """Convert only the segmentation alignment result to JSON-safe data.""" + return { + "image_path": self.image_path, + "bbox_name_image_path": self.bbox_name_image_path, + "asset_segments": [ + segment.to_manifest() for segment in self.asset_segments + ], + "groups": [group.to_manifest() for group in self.groups], + "table_segment": ( + self.table_segment.to_manifest() if self.table_segment else None + ), + "table_group": ( + self.table_group.to_manifest() if self.table_group else None + ), + } + + def to_spatial_manifest(self) -> dict[str, Any]: + """Convert only spatial relations and layout states to JSON-safe data.""" + return { + "image_path": self.image_path, + "bbox_name_image_path": self.bbox_name_image_path, + "anchor": self.anchor.to_manifest() if self.anchor else None, + "spatial_order": { + "left_to_right": [list(group) for group in self.x_order], + "front_to_back": [list(group) for group in self.y_order], + }, + "objects": [ + layout.to_manifest() for layout in self.asset_layouts + ], + } diff --git a/embodichain/gen_sim/prompt2scene/workflows/image_relations/state.py b/embodichain/gen_sim/prompt2scene/workflows/image_relations/state.py new file mode 100644 index 000000000..598530058 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/image_relations/state.py @@ -0,0 +1,42 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.image_relations.schema import ( + ImageRelationSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.request import Prompt2SceneInput +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SceneIntakeSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.attempt_state import AttemptState + +__all__ = ["ImageRelationsState"] + + +class ImageRelationsState(AttemptState): + """LangGraph state for image asset segmentation alignment.""" + + request: Prompt2SceneInput + scene_intake: SceneIntakeSpec + output_root: Path + segment_groups: list[dict[str, Any]] + raw_model_output: dict[str, Any] | None + image_relations: ImageRelationSpec | None diff --git a/embodichain/gen_sim/prompt2scene/workflows/image_relations/utils.py b/embodichain/gen_sim/prompt2scene/workflows/image_relations/utils.py new file mode 100644 index 000000000..27e3b1b39 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/image_relations/utils.py @@ -0,0 +1,435 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client import ( + ImageSegmentationClient, + ImageSegmentationError, + ImageSegmentationServerRequest, + ImageSegmentationServerResponse, + bbox_iou, + draw_labeled_bboxes, + draw_numbered_masks, + is_usable_segmentation_candidate, + sort_segments_by_bbox, +) +from embodichain.gen_sim.prompt2scene.workflows.image_relations.schema import ( + FILTER_EXTRA_INSTANCES_JSON_SCHEMA, + ImageAnchor, + ImageAssetLayout, + ImageAssetSegment, + ImageRelationGroup, + ImageRelationSpec, + SPATIAL_LAYOUT_JSON_SCHEMA, +) +from embodichain.gen_sim.prompt2scene.workflows.spatial import ( + GRID_VALUES, + validate_exact_asset_id_coverage, +) +from embodichain.gen_sim.prompt2scene.utils import log_api_request_start, log +from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( + IMAGE_SEGMENTS_STEP, + IMAGE_SPATIAL_RELATIONS_STEP, + RAW_MODEL_OUTPUT_FILENAME, + WorkflowArtifactWriter, +) +from embodichain.gen_sim.prompt2scene.workflows.image_relations.prompts import ( + build_filter_extra_instances_messages, + build_spatial_layout_messages, +) +from embodichain.gen_sim.prompt2scene.workflows.llm_output import ( + call_structured_json_model_step, + is_model_output_error, +) +from embodichain.gen_sim.prompt2scene.workflows.stage_errors import ( + format_attempt_error, +) + +__all__ = [ + "MAX_SEGMENT_RETRIES", + "OVERLAP_IOU_THRESHOLD", + "append_unique", + "apply_spatial_layout_output", + "asset_bbox_label", + "expand_asset_ids", + "filter_group_segments_with_vlm", + "filter_segments_with_vlm", + "merge_non_overlapping_segments", + "draw_labeled_bboxes", + "parse_anchor", + "parse_asset_states", + "parse_order_groups", + "path_token", + "prompt_text", + "remove_extra_numbered_segments", + "require_image_path", + "segment_prompt", + "segments_from_response", + "sort_segments_by_bbox", +] + +MAX_SEGMENT_RETRIES = 1 +OVERLAP_IOU_THRESHOLD = 0.5 + + +def require_image_path(state: dict[str, Any]) -> Path: + """Return the request image path or raise if the input is invalid.""" + image_path = state["request"].image_path + if image_path is None: + raise ValueError("Image relations requires request.image_path.") + return image_path + + +def prompt_text(name: str) -> str: + """Convert an asset name to a natural-language prompt.""" + return name.replace("_", " ") + + +def asset_bbox_label(asset_id: str) -> str: + """Convert an internal asset id into a display label.""" + prefix = "interact_" + return asset_id[len(prefix) :] if asset_id.startswith(prefix) else asset_id + + +def expand_asset_ids(asset_id: str, count: int) -> list[str]: + """Expand a grouped asset id into instance ids.""" + return [f"{asset_id}_{index}" for index in range(count)] + + +def path_token(value: str) -> str: + """Convert a label into a filesystem-safe token.""" + token = "".join(character if character.isalnum() else "_" for character in value) + return token.strip("_")[:80] or "prompt" + + +def append_unique(values: list[str], value: str) -> list[str]: + """Append a string only if it does not already exist in the list.""" + return values if value in values else values + [value] + + +def segment_prompt( + *, + image_path: Path, + prompt: str, +) -> ImageSegmentationServerResponse: + """Call the segmentation server with a single prompt.""" + client = ImageSegmentationClient() + log_api_request_start( + step=IMAGE_SEGMENTS_STEP, + request="sam3_segment", + prompt=prompt, + ) + result = client.segment( + ImageSegmentationServerRequest(prompt=prompt, image_path=image_path), + max_retries=MAX_SEGMENT_RETRIES, + ) + if isinstance(result, ImageSegmentationError): + log.log_warning(result.error_message) + raise RuntimeError(result.error_message) + return result + + +def segments_from_response( + *, + group: dict[str, Any], + response: ImageSegmentationServerResponse, + source_prompt: str, +) -> list[dict[str, Any]]: + """Convert segmentation server output into internal segment dicts.""" + segments = [] + for candidate in response.result.candidates: + if not is_usable_segmentation_candidate(candidate): + continue + segments.append( + { + "segment_id": f"{group['name']}_{len(segments)}", + "bbox_xyxy": list(candidate.bbox_xyxy), + "score": float(candidate.score), + "mask_rle": candidate.mask_rle, + "source_prompt": source_prompt, + } + ) + return sort_segments_by_bbox(segments) + + +def apply_spatial_layout_output( + *, + image_relations: ImageRelationSpec, + raw_model_output: dict[str, Any], +) -> ImageRelationSpec: + """Apply VLM spatial-layout output to an image-relations spec.""" + asset_ids = [segment.asset_id for segment in image_relations.asset_segments] + asset_id_set = set(asset_ids) + + anchor = parse_anchor(raw_model_output.get("anchor"), asset_id_set=asset_id_set) + x_order = parse_order_groups( + raw_model_output.get("x_order"), + asset_ids=asset_ids, + field_name="x_order", + ) + y_order = parse_order_groups( + raw_model_output.get("y_order"), + asset_ids=asset_ids, + field_name="y_order", + ) + state_by_asset_id = parse_asset_states( + raw_model_output.get("asset_states"), + asset_ids=asset_ids, + ) + asset_layouts = [ + ImageAssetLayout( + asset_id=asset_id, + is_arbitrary_layout=state_by_asset_id[asset_id]["is_arbitrary_layout"], + reason=state_by_asset_id[asset_id]["reason"], + ) + for asset_id in asset_ids + ] + return ImageRelationSpec( + status=image_relations.status, + image_path=image_relations.image_path, + asset_segments=image_relations.asset_segments, + groups=image_relations.groups, + table_segment=image_relations.table_segment, + table_group=image_relations.table_group, + bbox_name_image_path=image_relations.bbox_name_image_path, + anchor=anchor, + x_order=x_order, + y_order=y_order, + asset_layouts=asset_layouts, + ) + + +def parse_anchor(raw_anchor: Any, *, asset_id_set: set[str]) -> ImageAnchor: + """Parse and validate the anchor entry.""" + if not isinstance(raw_anchor, dict): + raise ValueError("anchor must be an object.") + asset_id = str(raw_anchor.get("asset_id") or "").strip() + grid = str(raw_anchor.get("grid") or "").strip() + reason = str(raw_anchor.get("reason") or "").strip() + if asset_id not in asset_id_set: + raise ValueError(f"anchor.asset_id is not a known asset: {asset_id!r}.") + if grid not in GRID_VALUES: + raise ValueError(f"anchor.grid is not valid: {grid!r}.") + return ImageAnchor(asset_id=asset_id, grid=grid, reason=reason) + + +def parse_order_groups( + raw_order: Any, + *, + asset_ids: list[str], + field_name: str, +) -> list[list[str]]: + """Parse ordered asset-id groups from VLM output.""" + if not isinstance(raw_order, list) or not raw_order: + raise ValueError(f"{field_name} must be a non-empty list.") + + groups: list[list[str]] = [] + flattened: list[str] = [] + for group_index, raw_group in enumerate(raw_order): + if not isinstance(raw_group, list) or not raw_group: + raise ValueError(f"{field_name}[{group_index}] must be a non-empty list.") + group: list[str] = [] + for raw_asset_id in raw_group: + asset_id = str(raw_asset_id).strip() + group.append(asset_id) + flattened.append(asset_id) + groups.append(group) + + validate_exact_asset_id_coverage( + values=flattened, + expected_asset_ids=asset_ids, + context=field_name, + ) + return groups + + +def parse_asset_states( + raw_asset_states: Any, + *, + asset_ids: list[str], +) -> dict[str, dict[str, Any]]: + """Parse per-asset layout state annotations.""" + if not isinstance(raw_asset_states, list): + raise ValueError("asset_states must be a list.") + + state_by_asset_id: dict[str, dict[str, Any]] = {} + for state_index, raw_state in enumerate(raw_asset_states): + if not isinstance(raw_state, dict): + raise ValueError(f"asset_states[{state_index}] must be an object.") + asset_id = str(raw_state.get("asset_id") or "").strip() + is_arbitrary_layout = raw_state.get("is_arbitrary_layout") + reason = str(raw_state.get("reason") or "").strip() + if not isinstance(is_arbitrary_layout, bool): + raise ValueError( + f"asset_states[{state_index}].is_arbitrary_layout must be boolean." + ) + if not reason: + raise ValueError(f"asset_states[{state_index}].reason must be non-empty.") + if asset_id in state_by_asset_id: + raise ValueError(f"asset_states has duplicate asset_id: {asset_id!r}.") + state_by_asset_id[asset_id] = { + "is_arbitrary_layout": is_arbitrary_layout, + "reason": reason, + } + + validate_exact_asset_id_coverage( + values=list(state_by_asset_id), + expected_asset_ids=asset_ids, + context="asset_states", + ) + return state_by_asset_id + + +def filter_group_segments_with_vlm( + *, + llm: Any, + image_path: Path, + artifact_writer: WorkflowArtifactWriter, + group: dict[str, Any], + segments: list[dict[str, Any]], + stage: str, +) -> list[dict[str, Any]]: + """Ask VLM to remove wrong or duplicate instances from one SAM3 result.""" + segments = sort_segments_by_bbox(segments) + if not segments: + return segments + + round_name = artifact_writer.next_debug_round_name(label=f"{stage}_{group['name']}") + round_dir = artifact_writer.debug_round_dir(round_name) + debug_image_path = draw_numbered_masks( + image_path=image_path, + segments=segments, + output_path=round_dir / "mask.png", + ) + group["debug_images"] = append_unique( + group["debug_images"], + str(debug_image_path), + ) + log_api_request_start( + step=IMAGE_SEGMENTS_STEP, + request=f"vlm_filter_{stage}", + debug_image=str(debug_image_path), + ) + messages = build_filter_extra_instances_messages( + debug_image_path=debug_image_path, + name=group["name"], + description=group["description"], + expected_count=group["expected_count"], + class_candidate=group["class_candidate"], + ) + raw_model_output = call_structured_json_model_step( + llm=llm, + schema=FILTER_EXTRA_INSTANCES_JSON_SCHEMA, + messages=messages, + context=f"Image relation {stage} segmentation filtering", + step_name=IMAGE_SEGMENTS_STEP, + output_root=None, + attempt_count=0, + raw_output_writer=lambda payload: artifact_writer.write_debug_round_json( + round_name=round_name, + filename=RAW_MODEL_OUTPUT_FILENAME, + payload=payload, + ), + ) + return remove_extra_numbered_segments( + segments=segments, + raw_model_output=raw_model_output, + ) + + +def filter_segments_with_vlm( + *, + state: dict[str, Any], + llm: Any, + stage: str, +) -> dict[str, object]: + """Filter all segment groups with VLM and return an updated state patch.""" + segment_groups = [] + attempt_count = state["attempt_count"] + 1 + image_path = require_image_path(state) + artifact_writer = WorkflowArtifactWriter(state["output_root"], IMAGE_SEGMENTS_STEP) + + try: + for group in state["segment_groups"]: + group = dict(group) + group["segments"] = filter_group_segments_with_vlm( + llm=llm, + image_path=image_path, + artifact_writer=artifact_writer, + group=group, + segments=group["segments"], + stage=stage, + ) + segment_groups.append(group) + except Exception as exc: + if is_model_output_error(exc) or isinstance(exc, ValueError): + error = format_attempt_error("Image relations VLM filter", attempt_count, exc) + log.log_warning(error) + return { + "attempt_count": attempt_count, + "last_error": error, + "errors": state["errors"] + [error], + } + raise + + return { + "attempt_count": attempt_count, + "segment_groups": segment_groups, + "last_error": None, + } + + +def remove_extra_numbered_segments( + *, + segments: list[dict[str, Any]], + raw_model_output: dict[str, Any], +) -> list[dict[str, Any]]: + """Remove numbered masks flagged as extra by the VLM.""" + extra_numbers = raw_model_output.get("extra_instance_numbers") + if not isinstance(extra_numbers, list): + raise ValueError("extra_instance_numbers must be a list.") + extra_indices = {int(number) - 1 for number in extra_numbers} + if any(index < 0 or index >= len(segments) for index in extra_indices): + raise ValueError("VLM returned an out-of-range extra mask number.") + kept = [ + segment for index, segment in enumerate(segments) if index not in extra_indices + ] + return kept + + +def merge_non_overlapping_segments( + *, + existing: list[dict[str, Any]], + incoming: list[dict[str, Any]], + limit: int, +) -> list[dict[str, Any]]: + """Merge non-overlapping segments until a limit is reached.""" + merged = list(existing) + for segment in sorted( + incoming, key=lambda item: float(item["score"]), reverse=True + ): + if len(merged) >= limit: + break + if all( + bbox_iou(segment["bbox_xyxy"], other["bbox_xyxy"]) < OVERLAP_IOU_THRESHOLD + for other in merged + ): + merged.append(segment) + return sort_segments_by_bbox(merged) diff --git a/embodichain/gen_sim/prompt2scene/workflows/llm_output.py b/embodichain/gen_sim/prompt2scene/workflows/llm_output.py new file mode 100644 index 000000000..bcc98bcbb --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/llm_output.py @@ -0,0 +1,285 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Callable + +from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( + WorkflowArtifactWriter, + write_next_raw_model_output, +) + +__all__ = [ + "bind_structured_output", + "coerce_json_object_output", + "is_model_output_error", + "call_structured_json_model_step", + "StructuredModelCallError", + "validate_json_schema", +] + + +class StructuredModelCallError(Exception): + """Retryable structured-model call failure.""" + + def __init__( + self, + *, + context: str, + attempt_count: int, + original_exc: Exception, + ) -> None: + self.context = context + self.attempt_count = attempt_count + self.original_exc = original_exc + super().__init__(str(original_exc)) + + +def bind_structured_output(llm: Any, schema: dict[str, Any]) -> Any: + """Bind a JSON schema to an LLM when the model wrapper supports it.""" + if hasattr(llm, "with_structured_output"): + return llm.with_structured_output(schema) + return llm + + +def coerce_json_object_output(response: Any, *, context: str) -> dict[str, Any]: + """Coerce a model response into a JSON object.""" + if isinstance(response, dict): + return response + + content = getattr(response, "content", response) + if isinstance(content, dict): + return content + + if isinstance(content, list): + text_parts = [ + item.get("text", "") + for item in content + if isinstance(item, dict) and item.get("type") == "text" + ] + content = "\n".join(text_parts) + + if isinstance(content, str): + return _parse_json_text(content, context=context) + + raise ValueError(f"{context} model output has unsupported type: {type(response)!r}") + + +def is_model_output_error(exc: Exception) -> bool: + """Return whether an exception is a retryable model output formatting error.""" + class_name = exc.__class__.__name__ + module_name = exc.__class__.__module__ + return ( + class_name + in { + "JSONDecodeError", + "OutputParserException", + "SchemaValidationError", + "ValidationError", + "StructuredModelCallError", + } + or module_name.startswith("pydantic") + ) + + +def validate_json_schema( + value: Any, + schema: dict[str, Any], + *, + context: str, +) -> None: + """Validate model output against the subset of JSON Schema used locally.""" + _validate_schema_value(value, schema, path=context) + + +def call_structured_json_model_step( + *, + llm: Any, + schema: dict[str, Any], + messages: list[dict[str, Any]], + context: str, + step_name: str, + output_root: Path | None, + attempt_count: int, + raw_output_label: str | None = None, + artifact_writer: WorkflowArtifactWriter | None = None, + raw_output_writer: Callable[[dict[str, Any]], None] | None = None, +) -> dict[str, Any]: + """Call a structured-output model, validate JSON, and persist raw output.""" + model = bind_structured_output(llm, schema) + try: + response = model.invoke(messages) + raw_model_output = coerce_json_object_output(response, context=context) + validate_json_schema( + raw_model_output, + schema, + context=f"{context} output", + ) + except Exception as exc: + if is_model_output_error(exc) or isinstance(exc, ValueError): + raise StructuredModelCallError( + context=context, + attempt_count=attempt_count, + original_exc=exc, + ) from exc + raise + + if raw_output_writer is not None: + raw_output_writer(raw_model_output) + elif artifact_writer is not None: + artifact_writer.write_next_raw_model_output( + payload=raw_model_output, + label=raw_output_label, + ) + elif output_root is not None: + write_next_raw_model_output( + output_root=output_root, + step_name=step_name, + payload=raw_model_output, + label=raw_output_label, + ) + return raw_model_output + + +def _parse_json_text(content: str, *, context: str) -> dict[str, Any]: + stripped = content.strip() + if stripped.startswith("```"): + lines = stripped.splitlines() + if lines and lines[0].startswith("```"): + lines = lines[1:] + if lines and lines[-1].startswith("```"): + lines = lines[:-1] + stripped = "\n".join(lines).strip() + parsed = json.loads(stripped) + if not isinstance(parsed, dict): + raise ValueError(f"{context} model output must be a JSON object.") + return parsed + + +def _validate_schema_value(value: Any, schema: dict[str, Any], *, path: str) -> None: + expected_type = schema.get("type") + if expected_type is not None: + _validate_type(value, expected_type, path=path) + + enum_values = schema.get("enum") + if isinstance(enum_values, list) and value not in enum_values: + raise ValueError(f"{path} must be one of {enum_values}.") + + if expected_type == "object" or isinstance(value, dict): + _validate_object(value, schema, path=path) + elif expected_type == "array" or isinstance(value, list): + _validate_array(value, schema, path=path) + elif expected_type == "string" or isinstance(value, str): + _validate_string(value, schema, path=path) + elif expected_type in {"integer", "number"}: + _validate_number(value, schema, path=path) + + +def _validate_type(value: Any, expected_type: Any, *, path: str) -> None: + if isinstance(expected_type, list): + if any(_matches_type(value, item) for item in expected_type): + return + raise ValueError(f"{path} must match one of these types: {expected_type}.") + + if not _matches_type(value, expected_type): + raise ValueError(f"{path} must be {expected_type}.") + + +def _matches_type(value: Any, expected_type: str) -> bool: + if expected_type == "object": + return isinstance(value, dict) + if expected_type == "array": + return isinstance(value, list) + if expected_type == "string": + return isinstance(value, str) + if expected_type == "integer": + return isinstance(value, int) and not isinstance(value, bool) + if expected_type == "number": + return isinstance(value, int | float) and not isinstance(value, bool) + if expected_type == "boolean": + return isinstance(value, bool) + if expected_type == "null": + return value is None + return True + + +def _validate_object(value: Any, schema: dict[str, Any], *, path: str) -> None: + if not isinstance(value, dict): + return + + properties = schema.get("properties") + properties = properties if isinstance(properties, dict) else {} + + required = schema.get("required", []) + if isinstance(required, list): + missing = [key for key in required if key not in value] + if missing: + raise ValueError(f"{path} missing required keys: {missing}.") + + if schema.get("additionalProperties") is False: + extra = sorted(set(value) - set(properties)) + if extra: + raise ValueError(f"{path} has unexpected keys: {extra}.") + + for key, child_schema in properties.items(): + if key not in value or not isinstance(child_schema, dict): + continue + _validate_schema_value(value[key], child_schema, path=f"{path}.{key}") + + +def _validate_array(value: Any, schema: dict[str, Any], *, path: str) -> None: + if not isinstance(value, list): + return + + min_items = schema.get("minItems") + if isinstance(min_items, int) and len(value) < min_items: + raise ValueError(f"{path} must contain at least {min_items} items.") + + max_items = schema.get("maxItems") + if isinstance(max_items, int) and len(value) > max_items: + raise ValueError(f"{path} must contain at most {max_items} items.") + + items_schema = schema.get("items") + if not isinstance(items_schema, dict): + return + + for index, item in enumerate(value): + _validate_schema_value(item, items_schema, path=f"{path}[{index}]") + + +def _validate_string(value: Any, schema: dict[str, Any], *, path: str) -> None: + if not isinstance(value, str): + return + + min_length = schema.get("minLength") + if isinstance(min_length, int) and len(value) < min_length: + raise ValueError(f"{path} must contain at least {min_length} characters.") + + max_length = schema.get("maxLength") + if isinstance(max_length, int) and len(value) > max_length: + raise ValueError(f"{path} must contain at most {max_length} characters.") + + +def _validate_number(value: Any, schema: dict[str, Any], *, path: str) -> None: + if not isinstance(value, int | float) or isinstance(value, bool): + return + + minimum = schema.get("minimum") + if isinstance(minimum, int | float) and value < minimum: + raise ValueError(f"{path} must be greater than or equal to {minimum}.") diff --git a/embodichain/gen_sim/prompt2scene/workflows/request.py b/embodichain/gen_sim/prompt2scene/workflows/request.py new file mode 100644 index 000000000..8cd01c30f --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/request.py @@ -0,0 +1,110 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from pathlib import Path + +__all__ = ["InputKind", "Prompt2SceneInput"] + +SUPPORTED_IMAGE_SUFFIXES: frozenset[str] = frozenset({".jpg", ".jpeg", ".png"}) + + +class InputKind(str, Enum): + """Supported prompt2scene input kinds.""" + + IMAGE = "image" + TEXT = "text" + + +@dataclass(frozen=True) +class Prompt2SceneInput: + """Normalized prompt2scene input.""" + + input_kind: InputKind + output_root: Path + image_path: Path | None = None + text: str | None = None + + @classmethod + def from_cli_args( + cls, + *, + image_path: Path | None, + text: str | None, + output_root: Path, + ) -> "Prompt2SceneInput": + """Create a prompt2scene input from CLI arguments. + + Args: + image_path: Input image path, if image mode is selected. + text: Text prompt, if text mode is selected. + output_root: Directory where prompt2scene outputs are written. + + Returns: + Normalized prompt2scene input. + + Raises: + FileNotFoundError: If the image input path does not exist. + ValueError: If the image path is invalid or text input is empty. + """ + output_root = output_root.expanduser().resolve() + + if image_path is not None: + image_path = image_path.expanduser().resolve() + cls._validate_image_path(image_path) + return cls( + input_kind=InputKind.IMAGE, + image_path=image_path, + output_root=output_root, + ) + + if text is None or not text.strip(): + raise ValueError("Text input must be non-empty.") + + return cls( + input_kind=InputKind.TEXT, + text=text.strip(), + output_root=output_root, + ) + + def to_manifest(self) -> dict[str, str]: + """Convert the input to a JSON-serializable manifest.""" + manifest: dict[str, str] = { + "input_kind": self.input_kind.value, + "output_root": str(self.output_root), + } + if self.input_kind == InputKind.IMAGE: + image_path = self.image_path + manifest["image_path"] = str(image_path) + else: + text = self.text + manifest["text"] = "" if text is None else text + return manifest + + @staticmethod + def _validate_image_path(image_path: Path) -> None: + """Validate supported image input paths.""" + if not image_path.exists(): + raise FileNotFoundError(f"Image input not found: {image_path}") + if not image_path.is_file(): + raise ValueError(f"Image input is not a file: {image_path}") + if image_path.suffix.lower() not in SUPPORTED_IMAGE_SUFFIXES: + raise ValueError( + "Image input must have one of these extensions: .jpg, .jpeg, .png" + ) diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/__init__.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/__init__.py new file mode 100644 index 000000000..ac8623089 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/__init__.py @@ -0,0 +1,24 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.graph import ( + build_scene_intake_graph, + run_scene_intake, +) + +__all__ = ["build_scene_intake_graph", "run_scene_intake"] diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/graph.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/graph.py new file mode 100644 index 000000000..77874b15c --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/graph.py @@ -0,0 +1,142 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +from langgraph.graph import END, StateGraph + +from embodichain.gen_sim.prompt2scene.llms import ( + OpenAICompatibleLLMCfg, + build_chat_model, +) +from embodichain.gen_sim.prompt2scene.utils import log +from embodichain.gen_sim.prompt2scene.workflows.stage_errors import ( + format_result_missing_error, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SceneIntakeSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.nodes import ( + call_vlm_scene_intake_node, + call_vlm_verify_scene_intake_node, + normalize_scene_intake_node, + normalize_verified_scene_intake_node, + prepare_input_node, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.state import ( + SceneIntakeState, +) +from embodichain.gen_sim.prompt2scene.workflows.request import Prompt2SceneInput + +__all__ = ["build_scene_intake_graph", "run_scene_intake"] + + +def route_after_normalize(state: SceneIntakeState) -> str: + """Route to retry or verify after draft scene intake normalization.""" + if state["draft_scene_intake"] is not None: + return "verify" + if state["attempt_count"] < state["max_attempts"]: + return "retry" + return "end" + + +def route_after_verified_normalize(state: SceneIntakeState) -> str: + """Route to retry or finish after scene intake verifier normalization.""" + if state["scene_intake"] is not None: + return "end" + if state["attempt_count"] < state["max_attempts"]: + return "retry" + return "end" + + +def build_scene_intake_graph(llm: Any) -> Any: + """Build the fixed LangGraph scene intake workflow.""" + graph = StateGraph(SceneIntakeState) + graph.add_node("prepare_input", prepare_input_node) + graph.add_node( + "call_vlm_scene_intake", + lambda state: call_vlm_scene_intake_node(state, llm=llm), + ) + graph.add_node("normalize_scene_intake", normalize_scene_intake_node) + graph.add_node( + "call_vlm_verify_scene_intake", + lambda state: call_vlm_verify_scene_intake_node(state, llm=llm), + ) + graph.add_node( + "normalize_verified_scene_intake", + normalize_verified_scene_intake_node, + ) + + graph.set_entry_point("prepare_input") + graph.add_edge("prepare_input", "call_vlm_scene_intake") + graph.add_edge("call_vlm_scene_intake", "normalize_scene_intake") + graph.add_conditional_edges( + "normalize_scene_intake", + route_after_normalize, + { + "retry": "call_vlm_scene_intake", + "verify": "call_vlm_verify_scene_intake", + "end": END, + }, + ) + graph.add_edge("call_vlm_verify_scene_intake", "normalize_verified_scene_intake") + graph.add_conditional_edges( + "normalize_verified_scene_intake", + route_after_verified_normalize, + { + "retry": "call_vlm_verify_scene_intake", + "end": END, + }, + ) + return graph.compile() + + +def run_scene_intake( + request: Prompt2SceneInput, + llm_cfg: OpenAICompatibleLLMCfg, +) -> SceneIntakeSpec: + """Run fixed VLM-based scene intake for one prompt2scene request.""" + llm = build_chat_model(llm_cfg) + graph = build_scene_intake_graph(llm) + result = graph.invoke( + { + "request": request, + "messages": [], + "raw_model_output": None, + "draft_scene_intake": None, + "scene_intake": None, + "attempt_count": 0, + "max_attempts": llm_cfg.max_attempts, + "last_error": None, + "errors": [], + } + ) + + scene_intake = result.get("scene_intake") + if scene_intake is not None: + return scene_intake + + error = format_result_missing_error( + "Scene intake", + "SceneIntakeSpec", + attempt_count=result.get("attempt_count", 0), + last_error=result.get("last_error"), + errors=result.get("errors", []), + ) + log.log_warning(error) + raise RuntimeError(error) diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/nodes.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/nodes.py new file mode 100644 index 000000000..8c7baf55c --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/nodes.py @@ -0,0 +1,211 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SCENE_INTAKE_JSON_SCHEMA, + SceneIntakeSpec, +) +from embodichain.gen_sim.prompt2scene.utils import ( + log_api_request_start, + log, +) +from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( + SCENE_INTAKE_STEP, + WorkflowArtifactWriter, +) +from embodichain.gen_sim.prompt2scene.workflows.llm_output import ( + StructuredModelCallError, + call_structured_json_model_step, +) +from embodichain.gen_sim.prompt2scene.workflows.stage_errors import ( + format_attempt_error, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.prompts import ( + build_scene_intake_messages, + build_scene_intake_verifier_messages, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.state import ( + SceneIntakeState, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.utils import ( + build_scene_intake_spec, +) + +__all__ = [ + "call_vlm_scene_intake_node", + "call_vlm_verify_scene_intake_node", + "normalize_scene_intake_node", + "normalize_verified_scene_intake_node", + "prepare_input_node", +] + + +def prepare_input_node(state: SceneIntakeState) -> dict[str, object]: + """Prepare chat messages for the scene intake model call.""" + return {"messages": build_scene_intake_messages(state["request"])} + + +def call_vlm_scene_intake_node( + state: SceneIntakeState, + *, + llm: Any, +) -> dict[str, object]: + """Call the configured VLM for fixed scene intake extraction.""" + attempt_count = state["attempt_count"] + 1 + + try: + log_api_request_start( + step=SCENE_INTAKE_STEP, + request="extract", + attempt=attempt_count, + ) + artifact_writer = WorkflowArtifactWriter( + state["request"].output_root, + SCENE_INTAKE_STEP, + ) + raw_model_output = call_structured_json_model_step( + llm=llm, + schema=SCENE_INTAKE_JSON_SCHEMA, + messages=state["messages"], + context="Scene intake", + step_name=SCENE_INTAKE_STEP, + output_root=None, + attempt_count=attempt_count, + raw_output_label="extract", + artifact_writer=artifact_writer, + ) + except StructuredModelCallError as exc: + error = format_attempt_error("Scene intake", attempt_count, exc) + log.log_warning(error) + return { + "attempt_count": attempt_count, + "raw_model_output": None, + "last_error": error, + "errors": state["errors"] + [error], + } + + return { + "attempt_count": attempt_count, + "raw_model_output": raw_model_output, + "last_error": None, + } + + +def normalize_scene_intake_node(state: SceneIntakeState) -> dict[str, object]: + """Normalize raw VLM JSON into a draft scene intake schema.""" + raw_model_output = state["raw_model_output"] + if raw_model_output is None: + return {} + + try: + scene_intake = build_scene_intake_spec( + request=state["request"], + model_output=raw_model_output, + ) + except ValueError as exc: + error = format_attempt_error("Scene intake", state["attempt_count"], exc) + return { + "draft_scene_intake": None, + "last_error": error, + "errors": state["errors"] + [error], + } + + return {"draft_scene_intake": scene_intake, "scene_intake": None} + + +def call_vlm_verify_scene_intake_node( + state: SceneIntakeState, + *, + llm: Any, +) -> dict[str, object]: + """Ask VLM to verify and correct scene-intake grouping and counts.""" + draft_scene_intake = state["draft_scene_intake"] + if draft_scene_intake is None: + return {} + + attempt_count = state["attempt_count"] + 1 + messages = build_scene_intake_verifier_messages( + request=state["request"], + scene_intake=draft_scene_intake, + ) + + try: + log_api_request_start( + step=SCENE_INTAKE_STEP, + request="verify", + attempt=attempt_count, + ) + artifact_writer = WorkflowArtifactWriter( + state["request"].output_root, + SCENE_INTAKE_STEP, + ) + raw_model_output = call_structured_json_model_step( + llm=llm, + schema=SCENE_INTAKE_JSON_SCHEMA, + messages=messages, + context="Scene intake verifier", + step_name=SCENE_INTAKE_STEP, + output_root=None, + attempt_count=attempt_count, + raw_output_label="verify", + artifact_writer=artifact_writer, + ) + except StructuredModelCallError as exc: + error = format_attempt_error("Scene intake verifier", attempt_count, exc) + log.log_warning(error) + return { + "attempt_count": attempt_count, + "raw_model_output": None, + "scene_intake": None, + "last_error": error, + "errors": state["errors"] + [error], + } + + return { + "attempt_count": attempt_count, + "raw_model_output": raw_model_output, + "scene_intake": None, + "last_error": None, + } + + +def normalize_verified_scene_intake_node( + state: SceneIntakeState, +) -> dict[str, object]: + """Normalize verifier output into the final scene intake schema.""" + raw_model_output = state["raw_model_output"] + if raw_model_output is None: + return {} + + try: + scene_intake = build_scene_intake_spec( + request=state["request"], + model_output=raw_model_output, + ) + except ValueError as exc: + error = format_attempt_error("Scene intake verifier", state["attempt_count"], exc) + log.log_warning(error) + return { + "scene_intake": None, + "last_error": error, + "errors": state["errors"] + [error], + } + + return {"scene_intake": scene_intake, "last_error": None} diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/prompts.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/prompts.py new file mode 100644 index 000000000..611c5bf95 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/prompts.py @@ -0,0 +1,197 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +from typing import Any + +from embodichain.gen_sim.prompt2scene.prompts import render_prompt +from embodichain.gen_sim.prompt2scene.workflows.request import ( + InputKind, + Prompt2SceneInput, +) +from embodichain.gen_sim.prompt2scene.utils.io import image_to_data_url +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SceneIntakeSpec, +) + +__all__ = [ + "build_scene_intake_messages", + "build_scene_intake_verifier_messages", +] + +SCENE_INTAKE_PROMPT_NAME = "scene_intake.yaml" + + +def build_scene_intake_messages(request: Prompt2SceneInput) -> list[dict[str, Any]]: + """Build LangChain-compatible messages for scene intake.""" + if request.input_kind == InputKind.TEXT: + return _build_text_messages(request) + return _build_image_messages(request) + + +def build_scene_intake_verifier_messages( + *, + request: Prompt2SceneInput, + scene_intake: SceneIntakeSpec, +) -> list[dict[str, Any]]: + """Build messages for scene-intake group and count verification.""" + scene_intake_json = json.dumps( + { + "table": { + "name": scene_intake.table.name, + "description": scene_intake.table.description, + "complete_table_description": ( + scene_intake.table.complete_table_description + ), + "is_complete_visible_table": ( + scene_intake.table.is_complete_visible_table + ), + "class_candidate": list(scene_intake.table.class_candidate), + }, + "assets": [ + { + "name": asset.name, + "description": asset.description, + "class_candidate": list(asset.class_candidate), + "count": asset.count, + } + for asset in scene_intake.assets + ], + }, + ensure_ascii=False, + indent=2, + ) + if request.input_kind == InputKind.TEXT: + return _build_text_verifier_messages( + request=request, + scene_intake_json=scene_intake_json, + ) + return _build_image_verifier_messages( + request=request, + scene_intake_json=scene_intake_json, + ) + + +def _build_text_messages(request: Prompt2SceneInput) -> list[dict[str, Any]]: + return [ + { + "role": "system", + "content": render_prompt(SCENE_INTAKE_PROMPT_NAME, prompt_key="text_system"), + }, + { + "role": "user", + "content": render_prompt( + SCENE_INTAKE_PROMPT_NAME, + {"text": request.text or ""}, + prompt_key="text_user", + ), + }, + ] + + +def _build_image_messages(request: Prompt2SceneInput) -> list[dict[str, Any]]: + image_path = request.image_path + if image_path is None: + raise ValueError("Image input requires image_path.") + + return [ + { + "role": "system", + "content": render_prompt(SCENE_INTAKE_PROMPT_NAME, prompt_key="image_system"), + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": render_prompt( + SCENE_INTAKE_PROMPT_NAME, + prompt_key="image_user", + ), + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(image_path)}, + }, + ], + }, + ] + + +def _build_text_verifier_messages( + *, + request: Prompt2SceneInput, + scene_intake_json: str, +) -> list[dict[str, Any]]: + return [ + { + "role": "system", + "content": render_prompt( + SCENE_INTAKE_PROMPT_NAME, + prompt_key="verifier_system", + ), + }, + { + "role": "user", + "content": render_prompt( + SCENE_INTAKE_PROMPT_NAME, + { + "text": request.text or "", + "scene_intake_json": scene_intake_json, + }, + prompt_key="verifier_text_user", + ), + }, + ] + + +def _build_image_verifier_messages( + *, + request: Prompt2SceneInput, + scene_intake_json: str, +) -> list[dict[str, Any]]: + image_path = request.image_path + if image_path is None: + raise ValueError("Image input requires image_path.") + + return [ + { + "role": "system", + "content": render_prompt( + SCENE_INTAKE_PROMPT_NAME, + prompt_key="verifier_system", + ), + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": render_prompt( + SCENE_INTAKE_PROMPT_NAME, + {"scene_intake_json": scene_intake_json}, + prompt_key="verifier_image_user", + ), + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(image_path)}, + }, + ], + }, + ] diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/schema.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/schema.py new file mode 100644 index 000000000..80c9ca27c --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/schema.py @@ -0,0 +1,244 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.request import ( + InputKind, + Prompt2SceneInput, +) + +__all__ = [ + "SCENE_INTAKE_JSON_SCHEMA", + "SceneIntakeAsset", + "SceneIntakeInputRecord", + "SceneIntakeSpec", + "SceneIntakeTable", +] + +SCENE_INTAKE_JSON_SCHEMA: dict[str, Any] = { + "title": "SceneIntakeModelOutput", + "description": ( + "Objects and table information extracted from a text or image input." + ), + "type": "object", + "additionalProperties": False, + "properties": { + "table": { + "type": "object", + "additionalProperties": False, + "properties": { + "name": { + "type": "string", + "description": ( + "Canonical English class name for the visible table " + "or tabletop target, such as table, desk, dining_table, " + "coffee_table, workbench, or tabletop." + ), + }, + "description": { + "type": "string", + "minLength": 20, + "maxLength": 180, + "description": ( + "One concise standalone appearance description of the " + "visible table or tabletop region." + ), + }, + "complete_table_description": { + "type": "string", + "minLength": 20, + "maxLength": 220, + "description": ( + "One concise standalone description of a complete table " + "asset for text-to-3D generation, matching the visible " + "tabletop color, material, and texture." + ), + }, + "is_complete_visible_table": { + "type": "boolean", + "description": ( + "For image input, whether a mostly complete table is " + "visible and suitable as the final table geometry source. " + "For text input, this should be false." + ), + }, + "class_candidate": { + "type": "array", + "minItems": 5, + "maxItems": 5, + "description": ( + "Exactly five likely class names for segmenting the " + "visible table or tabletop target." + ), + "items": { + "type": "string", + "minLength": 1, + }, + }, + }, + "required": [ + "name", + "description", + "complete_table_description", + "is_complete_visible_table", + "class_candidate", + ], + }, + "assets": { + "type": "array", + "description": ( + "Object category groups on or intended for the tabletop scene." + ), + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "name": { + "type": "string", + "description": ( + "Canonical English object name, singular, " + "snake_case preferred." + ), + }, + "description": { + "type": "string", + "minLength": 20, + "maxLength": 180, + "description": ( + "One concise appearance description of the object for " + "image and 3D geometry generation." + ), + }, + "class_candidate": { + "type": "array", + "minItems": 5, + "maxItems": 5, + "description": ( + "Exactly five likely object class names for later " + "image detection or segmentation." + ), + "items": { + "type": "string", + "minLength": 1, + }, + }, + "count": { + "type": "integer", + "description": ( + "Number of repeated instances in this object category " + "group. Only group objects that can share the same name, " + "description, and class_candidate list." + ), + "minimum": 1, + }, + }, + "required": ["name", "description", "class_candidate", "count"], + }, + }, + }, + "required": ["table", "assets"], +} + + +@dataclass(frozen=True) +class SceneIntakeInputRecord: + """Normalized input source recorded by scene intake.""" + + input_kind: InputKind + text: str | None = None + image_path: str | None = None + + @classmethod + def from_request(cls, request: Prompt2SceneInput) -> "SceneIntakeInputRecord": + """Create an input record from a prompt2scene request.""" + return cls( + input_kind=request.input_kind, + text=request.text, + image_path=str(request.image_path) if request.image_path else None, + ) + + def to_manifest(self) -> dict[str, str | None]: + """Convert the input record to JSON-safe data.""" + return { + "input_kind": self.input_kind.value, + "text": self.text, + "image_path": self.image_path, + } + + +@dataclass(frozen=True) +class SceneIntakeTable: + """Table/support information extracted during scene intake.""" + + id: str = "table" + name: str = "table" + description: str = "" + complete_table_description: str = "" + is_complete_visible_table: bool = False + class_candidate: list[str] = field(default_factory=list) + + def to_manifest(self) -> dict[str, object]: + """Convert the table record to JSON-safe data.""" + return { + "id": self.id, + "name": self.name, + "description": self.description, + "complete_table_description": self.complete_table_description, + "is_complete_visible_table": self.is_complete_visible_table, + "class_candidate": list(self.class_candidate), + } + + +@dataclass(frozen=True) +class SceneIntakeAsset: + """Object category group extracted during scene intake.""" + + id: str + name: str + count: int = 1 + description: str = "" + class_candidate: list[str] = field(default_factory=list) + + def to_manifest(self) -> dict[str, object]: + """Convert the asset record to JSON-safe data.""" + return { + "id": self.id, + "name": self.name, + "count": self.count, + "description": self.description, + "class_candidate": list(self.class_candidate), + } + + +@dataclass(frozen=True) +class SceneIntakeSpec: + """Unified first-step scene intake output for text and image inputs.""" + + input: SceneIntakeInputRecord + table: SceneIntakeTable + assets: list[SceneIntakeAsset] + + def to_manifest(self) -> dict[str, object]: + """Convert the intake spec to JSON-safe data.""" + return { + "input": self.input.to_manifest(), + "table": self.table.to_manifest(), + "assets": [asset.to_manifest() for asset in self.assets], + } diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/state.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/state.py new file mode 100644 index 000000000..7a96619fb --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/state.py @@ -0,0 +1,37 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.request import Prompt2SceneInput +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SceneIntakeSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.attempt_state import AttemptState + +__all__ = ["SceneIntakeState"] + + +class SceneIntakeState(AttemptState): + """LangGraph state for the fixed scene intake workflow.""" + + request: Prompt2SceneInput + messages: list[Any] + raw_model_output: dict[str, Any] | None + draft_scene_intake: SceneIntakeSpec | None + scene_intake: SceneIntakeSpec | None diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/utils.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/utils.py new file mode 100644 index 000000000..e49fe9b3d --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/utils.py @@ -0,0 +1,229 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import re +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.request import Prompt2SceneInput +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SceneIntakeAsset, + SceneIntakeInputRecord, + SceneIntakeSpec, + SceneIntakeTable, +) + +__all__ = ["build_scene_intake_spec", "normalize_asset_name"] + + +def normalize_asset_name(name: str) -> str: + """Normalize an object name for stable asset IDs.""" + normalized = name.strip().lower() + normalized = normalized.replace("-", " ").replace("/", " ") + normalized = re.sub(r"[^a-z0-9\s_]", "", normalized) + normalized = re.sub(r"\s+", "_", normalized) + normalized = re.sub(r"_+", "_", normalized).strip("_") + return normalized or "object" + + +def build_scene_intake_spec( + *, + request: Prompt2SceneInput, + model_output: dict[str, Any], +) -> SceneIntakeSpec: + """Normalize raw VLM JSON into the stable scene intake schema.""" + _validate_exact_keys( + model_output, + allowed_keys={"table", "assets"}, + context="Scene intake model output", + ) + input_record = SceneIntakeInputRecord.from_request(request) + table = _parse_table(_require_mapping(model_output.get("table"), "table")) + assets = _parse_assets(_require_list(model_output.get("assets"), "assets")) + return SceneIntakeSpec(input=input_record, table=table, assets=assets) + + +def _parse_table(raw_table: dict[str, Any]) -> SceneIntakeTable: + _validate_exact_keys( + raw_table, + allowed_keys={ + "name", + "description", + "complete_table_description", + "is_complete_visible_table", + "class_candidate", + }, + context="Scene intake table", + ) + + if "name" not in raw_table: + raise ValueError("Scene intake table.name is required.") + raw_name = str(raw_table["name"]).strip() + if not raw_name: + raise ValueError("Scene intake table.name must be non-empty.") + name = normalize_asset_name(raw_name) + + if "description" not in raw_table: + raise ValueError("Scene intake table.description is required.") + description = str(raw_table["description"]).strip() + if not description: + raise ValueError("Scene intake table.description must be non-empty.") + + if "complete_table_description" not in raw_table: + raise ValueError("Scene intake table.complete_table_description is required.") + complete_table_description = str( + raw_table["complete_table_description"] + ).strip() + if not complete_table_description: + raise ValueError( + "Scene intake table.complete_table_description must be non-empty." + ) + + if "is_complete_visible_table" not in raw_table: + raise ValueError("Scene intake table.is_complete_visible_table is required.") + is_complete_visible_table = raw_table["is_complete_visible_table"] + if not isinstance(is_complete_visible_table, bool): + raise ValueError( + "Scene intake table.is_complete_visible_table must be a boolean." + ) + + class_candidate = _parse_class_candidate( + raw_table.get("class_candidate"), + asset_index="table", + raw_name=name, + ) + + return SceneIntakeTable( + name=name, + description=description, + complete_table_description=complete_table_description, + is_complete_visible_table=is_complete_visible_table, + class_candidate=class_candidate, + ) + + +def _parse_assets(raw_assets: list[Any]) -> list[SceneIntakeAsset]: + assets: list[SceneIntakeAsset] = [] + seen_names: set[str] = set() + + for asset_index, raw_asset in enumerate(raw_assets): + if not isinstance(raw_asset, dict): + raise ValueError(f"Scene intake asset {asset_index} must be an object.") + _validate_exact_keys( + raw_asset, + allowed_keys={"name", "description", "class_candidate", "count"}, + context=f"Scene intake asset {asset_index}", + ) + + if "name" not in raw_asset: + raise ValueError(f"Scene intake asset {asset_index}.name is required.") + raw_name = str(raw_asset["name"]).strip() + if not raw_name: + raise ValueError( + f"Scene intake asset {asset_index}.name must be non-empty." + ) + + if "description" not in raw_asset: + raise ValueError( + f"Scene intake asset {asset_index}.description is required." + ) + description = str(raw_asset["description"]).strip() + if not description: + raise ValueError( + f"Scene intake asset {asset_index}.description must be non-empty." + ) + + class_candidate = _parse_class_candidate( + raw_asset.get("class_candidate"), + asset_index=asset_index, + raw_name=raw_name, + ) + count = _parse_count(raw_asset.get("count"), asset_index=asset_index) + base_name = normalize_asset_name(raw_name) + name = base_name + suffix = 2 + while name in seen_names: + name = f"{base_name}_{suffix}" + suffix += 1 + seen_names.add(name) + assets.append( + SceneIntakeAsset( + id=f"interact_{name}", + name=name, + count=count, + description=description, + class_candidate=class_candidate, + ) + ) + return assets + + +def _parse_class_candidate( + raw_class_candidate: Any, + *, + asset_index: int | str, + raw_name: str, +) -> list[str]: + if not isinstance(raw_class_candidate, list): + raise ValueError( + f"Scene intake asset {asset_index}.class_candidate must be a list." + ) + class_candidate = [normalize_asset_name(str(item)) for item in raw_class_candidate] + if len(class_candidate) != 5: + raise ValueError( + f"Scene intake asset {asset_index}.class_candidate must contain exactly five entries." + ) + if any(not candidate for candidate in class_candidate): + raise ValueError( + f"Scene intake asset {asset_index}.class_candidate has empty entries." + ) + if class_candidate[0] != normalize_asset_name(raw_name): + raise ValueError( + f"Scene intake asset {asset_index}.class_candidate[0] must equal name." + ) + return class_candidate + + +def _parse_count(raw_count: Any, *, asset_index: int) -> int: + if not isinstance(raw_count, int) or isinstance(raw_count, bool): + raise ValueError(f"Scene intake asset {asset_index}.count must be an integer.") + if raw_count < 1: + raise ValueError(f"Scene intake asset {asset_index}.count must be >= 1.") + return raw_count + + +def _validate_exact_keys( + value: dict[str, Any], + *, + allowed_keys: set[str], + context: str, +) -> None: + extra_keys = sorted(set(value) - allowed_keys) + if extra_keys: + raise ValueError(f"{context} has unexpected keys: {extra_keys}.") + + +def _require_mapping(value: Any, context: str) -> dict[str, Any]: + if not isinstance(value, dict): + raise ValueError(f"{context} must be an object.") + return value + + +def _require_list(value: Any, context: str) -> list[Any]: + if not isinstance(value, list): + raise ValueError(f"{context} must be a list.") + return value diff --git a/embodichain/gen_sim/prompt2scene/workflows/spatial.py b/embodichain/gen_sim/prompt2scene/workflows/spatial.py new file mode 100644 index 000000000..b5f938685 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/spatial.py @@ -0,0 +1,309 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +__all__ = [ + "GRID_VALUE_LIST", + "GRID_VALUES", + "RELATION_VALUE_LIST", + "RELATION_VALUES", + "assign_grids_from_anchor_and_orders", + "derive_relations_from_orders", + "invert_relation", + "normalize_relation", + "transitive_relation_closure", + "validate_exact_asset_id_coverage", +] + +RELATION_VALUE_LIST = ["left_of", "front_of"] +RELATION_VALUES = frozenset(RELATION_VALUE_LIST) +INVERSE_RELATIONS = { + "left_of": "right_of", + "right_of": "left_of", + "front_of": "behind", + "behind": "front_of", +} + +GRID_VALUE_LIST = [ + "center", + "front", + "back", + "left_center", + "right_center", + "left_front", + "right_front", + "left_back", + "right_back", +] +GRID_VALUES = frozenset(GRID_VALUE_LIST) + + +def validate_exact_asset_id_coverage( + *, + values: list[str], + expected_asset_ids: list[str], + context: str, +) -> None: + """Validate that values contain every expected asset id exactly once.""" + expected = set(expected_asset_ids) + actual = set(values) + duplicates = sorted({asset_id for asset_id in values if values.count(asset_id) > 1}) + missing = sorted(expected - actual) + unknown = sorted(actual - expected) + if duplicates: + raise ValueError(f"{context} has duplicate asset ids: {duplicates}.") + if missing: + raise ValueError(f"{context} is missing asset ids: {missing}.") + if unknown: + raise ValueError(f"{context} has unknown asset ids: {unknown}.") + + +def assign_grids_from_anchor_and_orders( + *, + anchor_asset_id: str, + anchor_grid: str, + x_order: list[list[str]], + y_order: list[list[str]], + asset_ids: list[str], +) -> dict[str, str]: + """Assign 9-grid labels from one anchor grid and two object orderings.""" + anchor_x, anchor_y = _split_grid(anchor_grid) + x_indices = _order_indices(x_order) + y_indices = _order_indices(y_order) + anchor_x_index = x_indices[anchor_asset_id] + anchor_y_index = y_indices[anchor_asset_id] + + grids: dict[str, str] = {} + for asset_id in asset_ids: + x_label = _axis_label_from_anchor( + index=x_indices[asset_id], + anchor_index=anchor_x_index, + anchor_label=anchor_x, + before_label="left", + after_label="right", + ) + y_label = _axis_label_from_anchor( + index=y_indices[asset_id], + anchor_index=anchor_y_index, + anchor_label=anchor_y, + before_label="front", + after_label="back", + ) + grids[asset_id] = _join_grid(x_label=x_label, y_label=y_label) + return grids + + +def invert_relation(relation: str) -> str: + """Return the inverse of a supported spatial relation.""" + if relation not in INVERSE_RELATIONS: + raise ValueError(f"Unsupported spatial relation: {relation!r}.") + return INVERSE_RELATIONS[relation] + + +def normalize_relation( + *, + subject: str, + relation: str, + object_id: str, +) -> tuple[str, str, str]: + """Normalize a relation into a canonical directional axis edge.""" + if relation == "left_of": + return subject, "left_of", object_id + if relation == "right_of": + return object_id, "left_of", subject + if relation == "front_of": + return subject, "front_of", object_id + if relation == "behind": + return object_id, "front_of", subject + raise ValueError(f"Unsupported spatial relation: {relation!r}.") + + +def transitive_relation_closure( + relations: list[dict[str, str]], +) -> list[dict[str, str]]: + """Expand canonical left/front relations with transitive closure.""" + direct_edges: dict[str, set[tuple[str, str]]] = { + "left_of": set(), + "front_of": set(), + } + input_edges: set[tuple[str, str, str]] = set() + for relation_record in relations: + subject = relation_record["subject"] + relation = relation_record["relation"] + object_id = relation_record["object"] + canonical_subject, canonical_relation, canonical_object = normalize_relation( + subject=subject, + relation=relation, + object_id=object_id, + ) + if canonical_subject == canonical_object: + raise ValueError("Spatial relation cannot reference the same object.") + edge = (canonical_subject, canonical_object) + inverse_edge = (canonical_object, canonical_subject) + if inverse_edge in direct_edges[canonical_relation]: + raise ValueError( + "Conflicting spatial relations: " + f"{canonical_subject!r} {canonical_relation} {canonical_object!r}." + ) + direct_edges[canonical_relation].add(edge) + input_edges.add((subject, relation, object_id)) + + output: list[dict[str, str]] = [] + seen: set[tuple[str, str, str]] = set() + for canonical_relation, edges in direct_edges.items(): + for subject, object_id in sorted(_transitive_edges(edges)): + _append_relation( + output=output, + seen=seen, + subject=subject, + relation=canonical_relation, + object_id=object_id, + source=( + "input" + if (subject, canonical_relation, object_id) in input_edges + else "closure" + ), + ) + return output + + +def derive_relations_from_orders( + *, + x_order: list[list[str]], + y_order: list[list[str]], +) -> list[dict[str, str]]: + """Derive canonical relations from adjacent order groups.""" + relations: list[dict[str, str]] = [] + relations.extend(_relations_from_order_groups(x_order, relation="left_of")) + relations.extend(_relations_from_order_groups(y_order, relation="front_of")) + closed = transitive_relation_closure(relations) + return [ + { + **relation, + "source": "order" if relation["source"] == "input" else relation["source"], + } + for relation in closed + ] + + +def _order_indices(order: list[list[str]]) -> dict[str, int]: + return { + asset_id: group_index + for group_index, group in enumerate(order) + for asset_id in group + } + + +def _split_grid(grid: str) -> tuple[str, str]: + if grid == "center": + return "center", "center" + if grid in {"front", "back"}: + return "center", grid + if grid in {"left_center", "right_center"}: + return grid.split("_", maxsplit=1)[0], "center" + x_label, y_label = grid.split("_", maxsplit=1) + return x_label, y_label + + +def _axis_label_from_anchor( + *, + index: int, + anchor_index: int, + anchor_label: str, + before_label: str, + after_label: str, +) -> str: + if index < anchor_index: + return before_label + if index > anchor_index: + return after_label + return anchor_label + + +def _join_grid(*, x_label: str, y_label: str) -> str: + if x_label == "center" and y_label == "center": + return "center" + if x_label == "center": + return y_label + if y_label == "center": + return f"{x_label}_center" + return f"{x_label}_{y_label}" + + +def _relations_from_order_groups( + order_groups: list[list[str]], + *, + relation: str, +) -> list[dict[str, str]]: + relations: list[dict[str, str]] = [] + for earlier_group, later_group in zip(order_groups, order_groups[1:]): + for subject in earlier_group: + for object_id in later_group: + relations.append( + { + "subject": subject, + "relation": relation, + "object": object_id, + "source": "input", + } + ) + return relations + + +def _transitive_edges( + edges: set[tuple[str, str]], +) -> set[tuple[str, str]]: + adjacency: dict[str, set[str]] = {} + for subject, object_id in edges: + adjacency.setdefault(subject, set()).add(object_id) + adjacency.setdefault(object_id, set()) + + closure: set[tuple[str, str]] = set(edges) + for start in adjacency: + stack = list(adjacency[start]) + visited: set[str] = set() + while stack: + current = stack.pop() + if current in visited: + continue + visited.add(current) + closure.add((start, current)) + stack.extend(adjacency.get(current, ())) + return closure + + +def _append_relation( + *, + output: list[dict[str, str]], + seen: set[tuple[str, str, str]], + subject: str, + relation: str, + object_id: str, + source: str, +) -> None: + key = (subject, relation, object_id) + if key in seen: + return + seen.add(key) + output.append( + { + "subject": subject, + "relation": relation, + "object": object_id, + "source": source, + } + ) diff --git a/embodichain/gen_sim/prompt2scene/workflows/stage_errors.py b/embodichain/gen_sim/prompt2scene/workflows/stage_errors.py new file mode 100644 index 000000000..f8d8c2303 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/stage_errors.py @@ -0,0 +1,40 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +__all__ = ["format_attempt_error", "format_result_missing_error"] + + +def format_attempt_error(stage_name: str, attempt_count: int, exc: Exception) -> str: + """Format a retryable stage failure message.""" + return f"{stage_name} attempt {attempt_count} failed: {exc}" + + +def format_result_missing_error( + stage_name: str, + result_name: str, + *, + attempt_count: int, + last_error: str | None, + errors: list[str], +) -> str: + """Format a missing-final-result error message.""" + return ( + f"{stage_name} failed to produce a {result_name} after " + f"{attempt_count} attempts. Last error: {last_error}. " + f"All retryable errors: {errors}" + ) diff --git a/embodichain/gen_sim/prompt2scene/workflows/text_relations/__init__.py b/embodichain/gen_sim/prompt2scene/workflows/text_relations/__init__.py new file mode 100644 index 000000000..e2c035398 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/text_relations/__init__.py @@ -0,0 +1,24 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.workflows.text_relations.graph import ( + build_text_relations_graph, + run_text_relations, +) + +__all__ = ["build_text_relations_graph", "run_text_relations"] diff --git a/embodichain/gen_sim/prompt2scene/workflows/text_relations/graph.py b/embodichain/gen_sim/prompt2scene/workflows/text_relations/graph.py new file mode 100644 index 000000000..f6aa60785 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/text_relations/graph.py @@ -0,0 +1,124 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from langgraph.graph import END, StateGraph + +from embodichain.gen_sim.prompt2scene.llms import ( + OpenAICompatibleLLMCfg, + build_chat_model, +) +from embodichain.gen_sim.prompt2scene.utils import log +from embodichain.gen_sim.prompt2scene.workflows.stage_errors import ( + format_result_missing_error, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SceneIntakeSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.text_relations.nodes import ( + call_llm_text_relations_node, + normalize_text_relations_node, + prepare_text_relation_messages_node, +) +from embodichain.gen_sim.prompt2scene.workflows.text_relations.schema import ( + TextRelationSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.text_relations.state import ( + TextRelationsState, +) +from embodichain.gen_sim.prompt2scene.workflows.request import Prompt2SceneInput + +__all__ = ["build_text_relations_graph", "run_text_relations"] + + +def route_after_text_relation_normalize(state: TextRelationsState) -> str: + """Route to retry or finish after text relation normalization.""" + if state["text_relations"] is not None: + return "end" + if state["attempt_count"] < state["max_attempts"]: + return "retry" + return "end" + + +def build_text_relations_graph(llm: Any) -> Any: + """Build the fixed text spatial-relation extraction workflow.""" + graph = StateGraph(TextRelationsState) + graph.add_node( + "prepare_text_relation_messages", + prepare_text_relation_messages_node, + ) + graph.add_node( + "call_llm_text_relations", + lambda state: call_llm_text_relations_node(state, llm=llm), + ) + graph.add_node("normalize_text_relations", normalize_text_relations_node) + + graph.set_entry_point("prepare_text_relation_messages") + graph.add_edge("prepare_text_relation_messages", "call_llm_text_relations") + graph.add_edge("call_llm_text_relations", "normalize_text_relations") + graph.add_conditional_edges( + "normalize_text_relations", + route_after_text_relation_normalize, + { + "retry": "call_llm_text_relations", + "end": END, + }, + ) + return graph.compile() + + +def run_text_relations( + request: Prompt2SceneInput, + *, + scene_intake: SceneIntakeSpec, + llm_cfg: OpenAICompatibleLLMCfg, + output_root: Path, +) -> TextRelationSpec: + """Run text spatial-relation extraction for one prompt2scene request.""" + llm = build_chat_model(llm_cfg) + graph = build_text_relations_graph(llm) + result = graph.invoke( + { + "request": request, + "scene_intake": scene_intake, + "output_root": output_root, + "messages": [], + "raw_model_output": None, + "text_relations": None, + "attempt_count": 0, + "max_attempts": llm_cfg.max_attempts, + "last_error": None, + "errors": [], + } + ) + + text_relations = result.get("text_relations") + if text_relations is not None: + return text_relations + + error = format_result_missing_error( + "Text relations", + "TextRelationSpec", + attempt_count=result.get("attempt_count", 0), + last_error=result.get("last_error"), + errors=result.get("errors", []), + ) + log.log_warning(error) + raise RuntimeError(error) diff --git a/embodichain/gen_sim/prompt2scene/workflows/text_relations/nodes.py b/embodichain/gen_sim/prompt2scene/workflows/text_relations/nodes.py new file mode 100644 index 000000000..67b1fc3c1 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/text_relations/nodes.py @@ -0,0 +1,144 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.request import InputKind +from embodichain.gen_sim.prompt2scene.workflows.text_relations.schema import ( + TEXT_RELATIONS_JSON_SCHEMA, + TextRelationSpec, +) +from embodichain.gen_sim.prompt2scene.utils import ( + log_api_request_start, + log, +) +from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( + TEXT_RELATIONS_STEP, + WorkflowArtifactWriter, +) +from embodichain.gen_sim.prompt2scene.workflows.llm_output import ( + StructuredModelCallError, + call_structured_json_model_step, +) +from embodichain.gen_sim.prompt2scene.workflows.stage_errors import ( + format_attempt_error, +) +from embodichain.gen_sim.prompt2scene.workflows.text_relations.prompts import ( + build_text_relation_messages, +) +from embodichain.gen_sim.prompt2scene.workflows.text_relations.state import ( + TextRelationsState, +) +from embodichain.gen_sim.prompt2scene.workflows.text_relations.utils import ( + build_text_relation_spec, +) + +__all__ = [ + "call_llm_text_relations_node", + "normalize_text_relations_node", + "prepare_text_relation_messages_node", +] + + +def prepare_text_relation_messages_node( + state: TextRelationsState, +) -> dict[str, object]: + """Prepare text-relation extraction messages.""" + request = state["request"] + if request.input_kind != InputKind.TEXT: + raise ValueError("Text relations requires a text input.") + return { + "messages": build_text_relation_messages( + request=request, + scene_intake=state["scene_intake"], + ) + } + + +def call_llm_text_relations_node( + state: TextRelationsState, + *, + llm: Any, +) -> dict[str, object]: + """Call LLM to extract explicit text spatial constraints.""" + attempt_count = state["attempt_count"] + 1 + artifact_writer = WorkflowArtifactWriter( + state["output_root"], + TEXT_RELATIONS_STEP, + ) + + try: + log_api_request_start( + step=TEXT_RELATIONS_STEP, + request="extract", + attempt=attempt_count, + ) + raw_model_output = call_structured_json_model_step( + llm=llm, + schema=TEXT_RELATIONS_JSON_SCHEMA, + messages=state["messages"], + context="Text relations", + step_name=TEXT_RELATIONS_STEP, + output_root=None, + attempt_count=attempt_count, + raw_output_label="extract", + artifact_writer=artifact_writer, + ) + except StructuredModelCallError as exc: + error = format_attempt_error("Text relations", attempt_count, exc) + log.log_warning(error) + return { + "attempt_count": attempt_count, + "raw_model_output": None, + "last_error": error, + "errors": state["errors"] + [error], + } + + return { + "attempt_count": attempt_count, + "raw_model_output": raw_model_output, + "last_error": None, + } + + +def normalize_text_relations_node(state: TextRelationsState) -> dict[str, object]: + """Normalize raw LLM output into TextRelationSpec.""" + raw_model_output = state["raw_model_output"] + if raw_model_output is None: + return {} + + try: + text_relations = build_text_relation_spec( + scene_intake=state["scene_intake"], + model_output=raw_model_output, + ) + except ValueError as exc: + error = format_attempt_error("Text relations", state["attempt_count"], exc) + log.log_warning(error) + return { + "text_relations": None, + "last_error": error, + "errors": state["errors"] + [error], + } + + artifact_writer = WorkflowArtifactWriter( + state["output_root"], + TEXT_RELATIONS_STEP, + ) + artifact_writer.write_step_result(text_relations.to_manifest()) + return {"text_relations": text_relations, "last_error": None} diff --git a/embodichain/gen_sim/prompt2scene/workflows/text_relations/prompts.py b/embodichain/gen_sim/prompt2scene/workflows/text_relations/prompts.py new file mode 100644 index 000000000..a6f02e4f6 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/text_relations/prompts.py @@ -0,0 +1,55 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +from embodichain.gen_sim.prompt2scene.prompts import render_prompt +from embodichain.gen_sim.prompt2scene.workflows.request import Prompt2SceneInput +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SceneIntakeSpec, +) + +__all__ = ["build_text_relation_messages"] + +TEXT_RELATIONS_PROMPT_NAME = "text_relations.yaml" + + +def build_text_relation_messages( + *, + request: Prompt2SceneInput, + scene_intake: SceneIntakeSpec, +) -> list[dict[str, Any]]: + """Build messages for explicit text spatial-relation extraction.""" + asset_names = "\n".join(f"- {asset.name}" for asset in scene_intake.assets) + return [ + { + "role": "system", + "content": render_prompt(TEXT_RELATIONS_PROMPT_NAME, prompt_key="system"), + }, + { + "role": "user", + "content": render_prompt( + TEXT_RELATIONS_PROMPT_NAME, + { + "asset_names": asset_names, + "text": request.text or "", + }, + prompt_key="user", + ), + }, + ] diff --git a/embodichain/gen_sim/prompt2scene/workflows/text_relations/schema.py b/embodichain/gen_sim/prompt2scene/workflows/text_relations/schema.py new file mode 100644 index 000000000..db2e513ff --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/text_relations/schema.py @@ -0,0 +1,164 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.spatial import ( + GRID_VALUE_LIST, + RELATION_VALUE_LIST, +) + +__all__ = [ + "TEXT_RELATIONS_JSON_SCHEMA", + "TextObjectLayout", + "TextObjectRelation", + "TextRelationSpec", + "TextTableConstraint", +] + +TEXT_RELATIONS_JSON_SCHEMA: dict[str, Any] = { + "title": "TextRelationsOutput", + "type": "object", + "additionalProperties": False, + "properties": { + "object_relations": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "subject": {"type": "string", "minLength": 1}, + "relation": { + "type": "string", + "enum": RELATION_VALUE_LIST, + }, + "object": {"type": "string", "minLength": 1}, + "evidence": {"type": "string", "minLength": 1}, + }, + "required": ["subject", "relation", "object", "evidence"], + }, + }, + "table_constraints": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "asset": {"type": "string", "minLength": 1}, + "grid": { + "type": "string", + "enum": GRID_VALUE_LIST, + }, + "evidence": {"type": "string", "minLength": 1}, + }, + "required": ["asset", "grid", "evidence"], + }, + }, + "object_layouts": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "asset": {"type": "string", "minLength": 1}, + "is_arbitrary_layout": {"type": "boolean"}, + "reason": {"type": "string", "minLength": 1}, + }, + "required": ["asset", "is_arbitrary_layout", "reason"], + }, + }, + }, + "required": ["object_relations", "table_constraints", "object_layouts"], +} + + +@dataclass(frozen=True) +class TextObjectRelation: + """Text-stated relation between two scene-intake asset groups.""" + + subject: str + relation: str + object: str + evidence: str + + def to_manifest(self) -> dict[str, str]: + """Convert the relation to JSON-safe data.""" + return { + "subject": self.subject, + "relation": self.relation, + "object": self.object, + "evidence": self.evidence, + } + + +@dataclass(frozen=True) +class TextTableConstraint: + """Text-stated table grid constraint for one asset group.""" + + asset: str + grid: str + evidence: str + + def to_manifest(self) -> dict[str, str]: + """Convert the table constraint to JSON-safe data.""" + return { + "asset": self.asset, + "grid": self.grid, + "evidence": self.evidence, + } + + +@dataclass(frozen=True) +class TextObjectLayout: + """Text-stated object support-pose constraint.""" + + asset: str + is_arbitrary_layout: bool + reason: str + + def to_manifest(self) -> dict[str, object]: + """Convert the layout constraint to JSON-safe data.""" + return { + "asset": self.asset, + "is_arbitrary_layout": self.is_arbitrary_layout, + "reason": self.reason, + } + + +@dataclass(frozen=True) +class TextRelationSpec: + """Spatial constraints explicitly extracted from a text prompt.""" + + source_text: str + object_relations: list[TextObjectRelation] = field(default_factory=list) + table_constraints: list[TextTableConstraint] = field(default_factory=list) + object_layouts: list[TextObjectLayout] = field(default_factory=list) + + def to_manifest(self) -> dict[str, object]: + """Convert the text relations to JSON-safe data.""" + return { + "source_text": self.source_text, + "object_relations": [ + relation.to_manifest() for relation in self.object_relations + ], + "table_constraints": [ + constraint.to_manifest() for constraint in self.table_constraints + ], + "object_layouts": [layout.to_manifest() for layout in self.object_layouts], + } diff --git a/embodichain/gen_sim/prompt2scene/workflows/text_relations/state.py b/embodichain/gen_sim/prompt2scene/workflows/text_relations/state.py new file mode 100644 index 000000000..b8dfa4c9f --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/text_relations/state.py @@ -0,0 +1,42 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.request import Prompt2SceneInput +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SceneIntakeSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.text_relations.schema import ( + TextRelationSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.attempt_state import AttemptState + +__all__ = ["TextRelationsState"] + + +class TextRelationsState(AttemptState): + """LangGraph state for explicit text spatial-relation extraction.""" + + request: Prompt2SceneInput + scene_intake: SceneIntakeSpec + output_root: Path + messages: list[Any] + raw_model_output: dict[str, Any] | None + text_relations: TextRelationSpec | None diff --git a/embodichain/gen_sim/prompt2scene/workflows/text_relations/utils.py b/embodichain/gen_sim/prompt2scene/workflows/text_relations/utils.py new file mode 100644 index 000000000..58002713b --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/text_relations/utils.py @@ -0,0 +1,191 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SceneIntakeSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.spatial import ( + GRID_VALUES, + RELATION_VALUES, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.utils import ( + normalize_asset_name, +) +from embodichain.gen_sim.prompt2scene.workflows.text_relations.schema import ( + TextObjectLayout, + TextObjectRelation, + TextRelationSpec, + TextTableConstraint, +) + +__all__ = [ + "build_text_relation_spec", +] + + +def build_text_relation_spec( + *, + scene_intake: SceneIntakeSpec, + model_output: dict[str, Any], +) -> TextRelationSpec: + """Normalize raw LLM JSON into text relation constraints.""" + asset_names = {asset.name for asset in scene_intake.assets} + object_relations = _parse_object_relations( + model_output.get("object_relations"), + asset_names=asset_names, + ) + table_constraints = _parse_table_constraints( + model_output.get("table_constraints"), + asset_names=asset_names, + ) + object_layouts = _parse_object_layouts( + model_output.get("object_layouts"), + asset_names=asset_names, + ) + return TextRelationSpec( + source_text=scene_intake.input.text or "", + object_relations=object_relations, + table_constraints=table_constraints, + object_layouts=object_layouts, + ) + + +def _parse_object_relations( + raw_relations: Any, + *, + asset_names: set[str], +) -> list[TextObjectRelation]: + if not isinstance(raw_relations, list): + raise ValueError("text_relations.object_relations must be a list.") + relations: list[TextObjectRelation] = [] + seen: set[tuple[str, str, str]] = set() + for index, raw_relation in enumerate(raw_relations): + if not isinstance(raw_relation, dict): + raise ValueError( + f"text_relations.object_relations[{index}] must be an object." + ) + subject = _parse_asset_name(raw_relation.get("subject"), asset_names, index) + relation = str(raw_relation.get("relation") or "").strip() + object_name = _parse_asset_name(raw_relation.get("object"), asset_names, index) + evidence = str(raw_relation.get("evidence") or "").strip() + if relation not in RELATION_VALUES: + raise ValueError( + f"text_relations.object_relations[{index}].relation is invalid." + ) + if not evidence: + raise ValueError( + f"text_relations.object_relations[{index}].evidence is required." + ) + key = (subject, relation, object_name) + if key in seen: + continue + seen.add(key) + relations.append( + TextObjectRelation( + subject=subject, + relation=relation, + object=object_name, + evidence=evidence, + ) + ) + return relations + + +def _parse_table_constraints( + raw_constraints: Any, + *, + asset_names: set[str], +) -> list[TextTableConstraint]: + if not isinstance(raw_constraints, list): + raise ValueError("text_relations.table_constraints must be a list.") + constraints: list[TextTableConstraint] = [] + seen: set[tuple[str, str]] = set() + for index, raw_constraint in enumerate(raw_constraints): + if not isinstance(raw_constraint, dict): + raise ValueError( + f"text_relations.table_constraints[{index}] must be an object." + ) + asset = _parse_asset_name(raw_constraint.get("asset"), asset_names, index) + grid = str(raw_constraint.get("grid") or "").strip() + evidence = str(raw_constraint.get("evidence") or "").strip() + if grid not in GRID_VALUES: + raise ValueError( + f"text_relations.table_constraints[{index}].grid is invalid." + ) + if not evidence: + raise ValueError( + f"text_relations.table_constraints[{index}].evidence is required." + ) + key = (asset, grid) + if key in seen: + continue + seen.add(key) + constraints.append( + TextTableConstraint(asset=asset, grid=grid, evidence=evidence) + ) + return constraints + + +def _parse_object_layouts( + raw_layouts: Any, + *, + asset_names: set[str], +) -> list[TextObjectLayout]: + if not isinstance(raw_layouts, list): + raise ValueError("text_relations.object_layouts must be a list.") + layouts: list[TextObjectLayout] = [] + seen: set[str] = set() + for index, raw_layout in enumerate(raw_layouts): + if not isinstance(raw_layout, dict): + raise ValueError( + f"text_relations.object_layouts[{index}] must be an object." + ) + asset = _parse_asset_name(raw_layout.get("asset"), asset_names, index) + is_arbitrary_layout = raw_layout.get("is_arbitrary_layout") + reason = str(raw_layout.get("reason") or "").strip() + if not isinstance(is_arbitrary_layout, bool): + raise ValueError( + "text_relations.object_layouts" + f"[{index}].is_arbitrary_layout must be boolean." + ) + if not reason: + raise ValueError( + f"text_relations.object_layouts[{index}].reason is required." + ) + if asset in seen: + continue + seen.add(asset) + layouts.append( + TextObjectLayout( + asset=asset, + is_arbitrary_layout=is_arbitrary_layout, + reason=reason, + ) + ) + return layouts + + +def _parse_asset_name(raw_name: Any, asset_names: set[str], index: int) -> str: + name = normalize_asset_name(str(raw_name or "")) + if name not in asset_names: + raise ValueError( + f"text_relations item {index} references unknown scene asset: {name!r}." + ) + return name diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene/__init__.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/__init__.py new file mode 100644 index 000000000..015c41510 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/__init__.py @@ -0,0 +1,19 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +__all__: list[str] = [] diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene/graph.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/graph.py new file mode 100644 index 000000000..7431f0c0b --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/graph.py @@ -0,0 +1,97 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from langgraph.graph import END, StateGraph + +from embodichain.gen_sim.prompt2scene.utils import log +from embodichain.gen_sim.prompt2scene.workflows.stage_errors import ( + format_result_missing_error, +) +from embodichain.gen_sim.prompt2scene.workflows.image_relations.schema import ( + ImageRelationSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SceneIntakeSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.text_relations.schema import ( + TextRelationSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.unified_scene.schema import ( + UnifiedSceneSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.unified_scene.nodes import ( + build_unified_scene_node, +) +from embodichain.gen_sim.prompt2scene.workflows.unified_scene.state import ( + UnifiedSceneState, +) +from embodichain.gen_sim.prompt2scene.workflows.request import Prompt2SceneInput + +__all__ = ["build_unified_scene_graph", "run_unified_scene"] + + +def build_unified_scene_graph() -> Any: + """Build the fixed unified-scene assembly workflow.""" + graph = StateGraph(UnifiedSceneState) + graph.add_node("build_unified_scene", build_unified_scene_node) + graph.set_entry_point("build_unified_scene") + graph.add_edge("build_unified_scene", END) + return graph.compile() + + +def run_unified_scene( + request: Prompt2SceneInput, + *, + scene_intake: SceneIntakeSpec, + image_relations: ImageRelationSpec | None = None, + text_relations: TextRelationSpec | None = None, + output_root: Path, +) -> UnifiedSceneSpec: + """Run final unified-scene assembly for one prompt2scene request.""" + graph = build_unified_scene_graph() + result = graph.invoke( + { + "request": request, + "scene_intake": scene_intake, + "output_root": output_root, + "image_relations": image_relations, + "text_relations": text_relations, + "unified_scene": None, + "attempt_count": 0, + "max_attempts": 1, + "last_error": None, + "errors": [], + } + ) + + unified_scene = result.get("unified_scene") + if unified_scene is not None: + return unified_scene + + error = format_result_missing_error( + "Unified scene", + "UnifiedSceneSpec", + attempt_count=result.get("attempt_count", 0), + last_error=result.get("last_error"), + errors=result.get("errors", []), + ) + log.log_warning(error) + raise RuntimeError(error) diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene/nodes.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/nodes.py new file mode 100644 index 000000000..5d65a737a --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/nodes.py @@ -0,0 +1,57 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( + UNIFIED_SCENE_STEP, + WorkflowArtifactWriter, +) +from embodichain.gen_sim.prompt2scene.workflows.unified_scene.state import ( + UnifiedSceneState, +) +from embodichain.gen_sim.prompt2scene.workflows.unified_scene.utils import ( + build_unified_scene_from_image_relations, + build_unified_scene_from_text_relations, +) + +__all__ = ["build_unified_scene_node"] + + +def build_unified_scene_node(state: UnifiedSceneState) -> dict[str, object]: + """Assemble the final unified scene manifest.""" + scene_intake = state["scene_intake"] + image_relations = state.get("image_relations") + text_relations = state.get("text_relations") + + if image_relations is not None and image_relations.status == "ok": + unified_scene = build_unified_scene_from_image_relations( + scene_intake=scene_intake, + image_relations=image_relations, + ) + elif text_relations is not None: + unified_scene = build_unified_scene_from_text_relations( + scene_intake=scene_intake, + text_relations=text_relations, + ) + else: + raise ValueError("Unified scene requires image_relations or text_relations.") + + WorkflowArtifactWriter( + state["output_root"], + UNIFIED_SCENE_STEP, + ).write_step_result(unified_scene.to_manifest()) + return {"unified_scene": unified_scene} diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene/schema.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/schema.py new file mode 100644 index 000000000..f3d13125d --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/schema.py @@ -0,0 +1,157 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +__all__ = [ + "UnifiedObject", + "UnifiedSceneSpec", + "UnifiedSpatial", + "UnifiedSpatialAnchor", + "UnifiedSpatialRelation", + "UnifiedTable", +] + + +@dataclass(frozen=True) +class UnifiedTable: + """Unified table/support object.""" + + id: str + name: str + description: str + complete_table_description: str + is_complete_visible_table: bool + class_candidate: list[str] + image_path: str | None = None + mesh_path: str | None = None + grid_cells: dict[str, list[str]] | None = None + + def to_manifest(self) -> dict[str, Any]: + """Convert the table to JSON-safe data.""" + return { + "id": self.id, + "name": self.name, + "description": self.description, + "complete_table_description": self.complete_table_description, + "is_complete_visible_table": self.is_complete_visible_table, + "class_candidate": list(self.class_candidate), + "image_path": self.image_path, + "mesh_path": self.mesh_path, + "grid_cells": self.grid_cells, + } + + +@dataclass(frozen=True) +class UnifiedObject: + """Unified object instance used by downstream scene generation.""" + + id: str + name: str + description: str + class_candidate: list[str] + grid: str | None = None + is_arbitrary_layout: bool = False + layout_reason: str = "" + image_path: str | None = None + mesh_path: str | None = None + + def to_manifest(self) -> dict[str, Any]: + """Convert the object to JSON-safe data.""" + return { + "id": self.id, + "name": self.name, + "description": self.description, + "class_candidate": list(self.class_candidate), + "grid": self.grid, + "is_arbitrary_layout": self.is_arbitrary_layout, + "layout_reason": self.layout_reason, + "image_path": self.image_path, + "mesh_path": self.mesh_path, + } + + +@dataclass(frozen=True) +class UnifiedSpatialAnchor: + """Spatial anchor used to infer a full table grid.""" + + object_id: str + grid: str + reason: str = "" + + def to_manifest(self) -> dict[str, str]: + """Convert the anchor to JSON-safe data.""" + return { + "object_id": self.object_id, + "grid": self.grid, + "reason": self.reason, + } + + +@dataclass(frozen=True) +class UnifiedSpatialRelation: + """Unified pairwise spatial relation between two objects.""" + + subject: str + relation: str + object: str + source: str + + def to_manifest(self) -> dict[str, str]: + """Convert the relation to JSON-safe data.""" + return { + "subject": self.subject, + "relation": self.relation, + "object": self.object, + "source": self.source, + } + + +@dataclass(frozen=True) +class UnifiedSpatial: + """Unified spatial relations for a scene.""" + + anchor: UnifiedSpatialAnchor | None = None + relations: list[UnifiedSpatialRelation] = field(default_factory=list) + + def to_manifest(self) -> dict[str, Any]: + """Convert the spatial record to JSON-safe data.""" + return { + "anchor": self.anchor.to_manifest() if self.anchor else None, + "relations": [relation.to_manifest() for relation in self.relations], + } + + +@dataclass(frozen=True) +class UnifiedSceneSpec: + """Unified scene representation consumed by downstream generation steps.""" + + input: dict[str, Any] + table: UnifiedTable + objects: list[UnifiedObject] + spatial: UnifiedSpatial + + def to_manifest(self) -> dict[str, Any]: + """Convert the unified scene to JSON-safe data.""" + return { + "input": dict(self.input), + "table": self.table.to_manifest(), + "objects": [obj.to_manifest() for obj in self.objects], + "spatial": self.spatial.to_manifest(), + } diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene/state.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/state.py new file mode 100644 index 000000000..8152a6bf7 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/state.py @@ -0,0 +1,45 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.image_relations.schema import ( + ImageRelationSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SceneIntakeSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.text_relations.schema import ( + TextRelationSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.attempt_state import AttemptState +from embodichain.gen_sim.prompt2scene.workflows.request import Prompt2SceneInput + +__all__ = ["UnifiedSceneState"] + + +class UnifiedSceneState(AttemptState): + """LangGraph state for unified scene assembly.""" + + request: Prompt2SceneInput + scene_intake: SceneIntakeSpec + output_root: Path + image_relations: ImageRelationSpec | None + text_relations: TextRelationSpec | None + unified_scene: Any | None diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene/utils.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/utils.py new file mode 100644 index 000000000..e17b5e7b6 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/utils.py @@ -0,0 +1,332 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from collections import defaultdict +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.image_relations.schema import ( + ImageAnchor, + ImageRelationSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.spatial import ( + assign_grids_from_anchor_and_orders, + derive_relations_from_orders, + transitive_relation_closure, +) +from embodichain.gen_sim.prompt2scene.workflows.unified_scene.schema import ( + UnifiedObject, + UnifiedSceneSpec, + UnifiedSpatialAnchor, + UnifiedSpatialRelation, + UnifiedSpatial, + UnifiedTable, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( + SceneIntakeAsset, + SceneIntakeSpec, +) +from embodichain.gen_sim.prompt2scene.workflows.text_relations.schema import ( + TextObjectLayout, + TextRelationSpec, +) + +__all__ = [ + "build_unified_object", + "build_unified_object_specs", + "build_unified_scene_from_image_relations", + "build_unified_scene_from_text_relations", + "build_unified_spatial_anchor", + "build_unified_table", + "grid_cells_from_objects", + "object_ids_by_name", + "relations_by_object_id", + "resolve_image_layout", + "resolve_text_layout", + "text_grids_by_object_id", +] + + +def build_unified_object_specs( + assets: list[SceneIntakeAsset], +) -> list[dict[str, Any]]: + """Expand scene-intake assets into unified object instance specs.""" + specs: list[dict[str, Any]] = [] + for asset in assets: + for index in range(asset.count): + specs.append( + { + "id": f"{asset.id}_{index}", + "name": asset.name, + "description": asset.description, + "class_candidate": list(asset.class_candidate), + } + ) + return specs + + +def object_ids_by_name(object_specs: list[dict[str, Any]]) -> dict[str, list[str]]: + """Group expanded object ids by object name.""" + grouped: dict[str, list[str]] = defaultdict(list) + for spec in object_specs: + grouped[str(spec["name"])].append(str(spec["id"])) + return dict(grouped) + + +def build_unified_table( + scene_intake: SceneIntakeSpec, + *, + grid_cells: dict[str, list[str]] | None = None, +) -> dict[str, Any]: + """Build a unified table record from scene intake.""" + return { + "id": scene_intake.table.id, + "name": scene_intake.table.name, + "description": scene_intake.table.description, + "complete_table_description": ( + scene_intake.table.complete_table_description + ), + "is_complete_visible_table": scene_intake.table.is_complete_visible_table, + "class_candidate": list(scene_intake.table.class_candidate), + "image_path": None, + "mesh_path": None, + "grid_cells": grid_cells, + } + + +def build_unified_spatial_anchor(anchor: ImageAnchor | None) -> dict[str, Any] | None: + """Convert the image anchor to a unified spatial anchor record.""" + if anchor is None: + return None + return { + "object_id": anchor.asset_id, + "grid": anchor.grid, + "reason": anchor.reason, + } + + +def build_unified_object( + *, + spec: dict[str, Any], + grid: str | None, + is_arbitrary_layout: bool, + layout_reason: str, +) -> dict[str, Any]: + """Build one unified object record.""" + return { + "id": spec["id"], + "name": spec["name"], + "description": spec["description"], + "class_candidate": list(spec["class_candidate"]), + "grid": grid, + "is_arbitrary_layout": is_arbitrary_layout, + "layout_reason": layout_reason, + "image_path": None, + "mesh_path": None, + } + + +def resolve_image_layout( + asset_id: str, + layout_by_id: dict[str, Any], +) -> tuple[bool, str]: + """Resolve an image asset's layout state.""" + layout = layout_by_id.get(asset_id) + if layout is None: + return False, "" + return bool(layout.is_arbitrary_layout), str(layout.reason) + + +def resolve_text_layout( + name: str, + layout_by_name: dict[str, TextObjectLayout], +) -> tuple[bool, str]: + """Resolve a text asset's layout state.""" + layout = layout_by_name.get(name) + if layout is None: + return False, "" + return bool(layout.is_arbitrary_layout), str(layout.reason) + + +def text_grids_by_object_id( + *, + text_relations: TextRelationSpec, + ids_by_name: dict[str, list[str]], +) -> dict[str, str | None]: + """Assign explicit text table constraints to object ids.""" + grids: dict[str, str | None] = {object_id: None for ids in ids_by_name.values() for object_id in ids} + for constraint in text_relations.table_constraints: + for object_id in ids_by_name.get(constraint.asset, []): + grids[object_id] = constraint.grid + return grids + + +def grid_cells_from_objects(objects: list[dict[str, Any]]) -> dict[str, list[str]] | None: + """Build table grid cell membership from unified objects.""" + grid_cells: dict[str, list[str]] = { + "center": [], + "front": [], + "back": [], + "left_center": [], + "right_center": [], + "left_front": [], + "right_front": [], + "left_back": [], + "right_back": [], + } + any_grid = False + for obj in objects: + grid = obj.get("grid") + if not grid: + continue + any_grid = True + grid_cells.setdefault(str(grid), []).append(str(obj["id"])) + return grid_cells if any_grid else None + + +def relations_by_object_id( + *, + text_relations: TextRelationSpec, + ids_by_name: dict[str, list[str]], +) -> list[dict[str, str]]: + """Expand text relations to object-id relations.""" + relations: list[dict[str, str]] = [] + for relation in text_relations.object_relations: + subjects = ids_by_name.get(relation.subject, []) + objects = ids_by_name.get(relation.object, []) + for subject in subjects: + for object_id in objects: + if subject == object_id: + continue + relations.append( + { + "subject": subject, + "relation": relation.relation, + "object": object_id, + "source": "input", + } + ) + return relations + + +def build_unified_scene_from_image_relations( + *, + scene_intake: SceneIntakeSpec, + image_relations: ImageRelationSpec, +) -> UnifiedSceneSpec: + """Build a unified scene from image relation outputs.""" + object_specs = build_unified_object_specs(scene_intake.assets) + anchor = build_unified_spatial_anchor(image_relations.anchor) + if anchor is None: + raise ValueError("Image unified scene requires an anchor.") + layout_by_id = { + layout.asset_id: layout for layout in image_relations.asset_layouts + } + objects = [] + for spec in object_specs: + is_arbitrary_layout, layout_reason = resolve_image_layout( + spec["id"], + layout_by_id, + ) + objects.append( + UnifiedObject( + **build_unified_object( + spec=spec, + grid=anchor["grid"] if spec["id"] == anchor["object_id"] else None, + is_arbitrary_layout=is_arbitrary_layout, + layout_reason=layout_reason, + ) + ) + ) + relations = [ + UnifiedSpatialRelation(**relation) + for relation in derive_relations_from_orders( + x_order=image_relations.x_order, + y_order=image_relations.y_order, + ) + ] + return UnifiedSceneSpec( + input=scene_intake.input.to_manifest(), + table=UnifiedTable( + **build_unified_table( + scene_intake, + grid_cells=grid_cells_from_objects( + [object_.to_manifest() for object_ in objects] + ), + ) + ), + objects=objects, + spatial=UnifiedSpatial( + anchor=UnifiedSpatialAnchor(**anchor), + relations=relations, + ), + ) + + +def build_unified_scene_from_text_relations( + *, + scene_intake: SceneIntakeSpec, + text_relations: TextRelationSpec, +) -> UnifiedSceneSpec: + """Build a unified scene from text relation outputs.""" + object_specs = build_unified_object_specs(scene_intake.assets) + ids_by_name = object_ids_by_name(object_specs) + grid_by_id = text_grids_by_object_id( + text_relations=text_relations, + ids_by_name=ids_by_name, + ) + layout_by_name = { + layout.asset: layout for layout in text_relations.object_layouts + } + objects = [] + for spec in object_specs: + is_arbitrary_layout, layout_reason = resolve_text_layout( + spec["name"], + layout_by_name, + ) + objects.append( + UnifiedObject( + **build_unified_object( + spec=spec, + grid=grid_by_id.get(spec["id"]), + is_arbitrary_layout=is_arbitrary_layout, + layout_reason=layout_reason, + ) + ) + ) + relations = [ + UnifiedSpatialRelation(**relation) + for relation in transitive_relation_closure( + relations_by_object_id( + text_relations=text_relations, + ids_by_name=ids_by_name, + ) + ) + ] + return UnifiedSceneSpec( + input=scene_intake.input.to_manifest(), + table=UnifiedTable( + **build_unified_table( + scene_intake, + grid_cells=grid_cells_from_objects( + [object_.to_manifest() for object_ in objects] + ), + ) + ), + objects=objects, + spatial=UnifiedSpatial(anchor=None, relations=relations), + ) diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/__init__.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/__init__.py new file mode 100644 index 000000000..ac849443e --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/__init__.py @@ -0,0 +1,27 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.graph import ( + build_unified_scene_gen_graph, + run_unified_scene_gen, +) + +__all__ = [ + "build_unified_scene_gen_graph", + "run_unified_scene_gen", +] diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/graph.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/graph.py new file mode 100644 index 000000000..5d542b392 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/graph.py @@ -0,0 +1,106 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from langgraph.graph import END, StateGraph + +from embodichain.gen_sim.prompt2scene.llms import build_chat_model +from embodichain.gen_sim.prompt2scene.llms.config import OpenAICompatibleLLMCfg +from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.nodes import ( + fit_image_table_to_clutter_node, + fit_text_table_to_clutter_node, + generate_image_assets_node, + generate_text_assets_node, + generate_text_clutter_layout_node, + load_unified_scene_input_kind_node, +) +from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.state import ( + UnifiedSceneGenState, +) +__all__ = [ + "build_unified_scene_gen_graph", + "route_after_load_input_kind", + "run_unified_scene_gen", +] + + +def route_after_load_input_kind(state: UnifiedSceneGenState) -> str: + """Route unified-scene generation by the original input kind.""" + input_kind = state["input_kind"] + if input_kind == "text": + return "generate_text_assets" + if input_kind == "image": + return "generate_image_assets" + raise ValueError(f"Unsupported unified-scene input_kind: {input_kind!r}.") + + +def build_unified_scene_gen_graph() -> Any: + """Build the unified-scene generation graph.""" + graph = StateGraph(UnifiedSceneGenState) + graph.add_node("load_unified_scene_input_kind", load_unified_scene_input_kind_node) + graph.add_node("generate_text_assets", generate_text_assets_node) + graph.add_node("generate_text_clutter_layout", generate_text_clutter_layout_node) + graph.add_node("fit_text_table_to_clutter", fit_text_table_to_clutter_node) + graph.add_node("generate_image_assets", generate_image_assets_node) + graph.add_node("fit_image_table_to_clutter", fit_image_table_to_clutter_node) + + graph.set_entry_point("load_unified_scene_input_kind") + graph.add_conditional_edges( + "load_unified_scene_input_kind", + route_after_load_input_kind, + { + "generate_text_assets": "generate_text_assets", + "generate_image_assets": "generate_image_assets", + }, + ) + graph.add_edge("generate_text_assets", "generate_text_clutter_layout") + graph.add_edge("generate_text_clutter_layout", "fit_text_table_to_clutter") + graph.add_edge("fit_text_table_to_clutter", END) + graph.add_edge("generate_image_assets", "fit_image_table_to_clutter") + graph.add_edge("fit_image_table_to_clutter", END) + return graph.compile() + + +def run_unified_scene_gen( + output_root: Path, + *, + unified_scene_result_path: Path | None = None, + llm_cfg: OpenAICompatibleLLMCfg | None = None, +) -> UnifiedSceneGenState: + """Run downstream generation routing from a unified-scene result.""" + llm = build_chat_model(llm_cfg) if llm_cfg is not None else None + initial_state: UnifiedSceneGenState = { + "output_root": output_root, + "unified_scene_result_path": unified_scene_result_path, + "llm": llm, + "unified_scene": None, + "input_kind": None, + "table_result": None, + "text_object_results": [], + "text_clutter_settle_result": None, + "image_objects_layout_result": None, + "table_fit_result": None, + "generation_status": None, + "attempt_count": 0, + "max_attempts": 1, + "last_error": None, + "errors": [], + } + return build_unified_scene_gen_graph().invoke(initial_state) diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/nodes.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/nodes.py new file mode 100644 index 000000000..e12e41f12 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/nodes.py @@ -0,0 +1,392 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json + +from embodichain.gen_sim.prompt2scene.utils.log import log_info +from embodichain.gen_sim.prompt2scene.utils.io import write_json +from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.state import ( + UnifiedSceneGenState, +) +from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( + UNIFIED_SCENE_GEN_STEP, + UNIFIED_SCENE_STEP, + WorkflowArtifactWriter, +) +from embodichain.gen_sim.prompt2scene.agent_tools.tools.text_asset_generation import ( + generate_text_object_assets, + generate_text_table_asset, +) +from embodichain.gen_sim.prompt2scene.agent_tools.tools.text_scene_metric_scale import ( + estimate_text_scene_metric_scale, +) +from embodichain.gen_sim.prompt2scene.agent_tools.tools.text_clutter_layout import ( + generate_text_clutter_layout, +) +from embodichain.gen_sim.prompt2scene.agent_tools.tools.table_fit_scene import ( + fit_image_scene_table, + fit_text_scene_table, +) +from embodichain.gen_sim.prompt2scene.agent_tools.tools.image_scene_asset_generation import ( + generate_image_scene_assets, +) +from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.paths import ( + UnifiedScenePaths, +) +from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.prompts import ( + build_text_metric_scale_messages, +) +from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.schema import ( + IMAGE_METRIC_SCALE_JSON_SCHEMA, +) +from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.scene_update import ( + update_unified_scene, +) + +__all__ = [ + "fit_image_table_to_clutter_node", + "fit_text_table_to_clutter_node", + "generate_image_assets_node", + "generate_text_assets_node", + "generate_text_clutter_layout_node", + "load_unified_scene_input_kind_node", +] + + +def load_unified_scene_input_kind_node( + state: UnifiedSceneGenState, +) -> dict[str, object]: + """Load unified-scene output and determine the generation route.""" + paths = UnifiedScenePaths(state["output_root"]) + result_path = paths.resolve_scene_result(state["unified_scene_result_path"]) + if not result_path.is_file(): + raise FileNotFoundError(f"Unified scene result not found: {result_path}") + + with result_path.open("r", encoding="utf-8") as f: + unified_scene = json.load(f) + if not isinstance(unified_scene, dict): + raise ValueError("Unified scene result must be a JSON object.") + + input_record = unified_scene.get("input") + if not isinstance(input_record, dict): + raise ValueError("Unified scene result requires input object.") + + input_kind = str(input_record.get("input_kind") or "").strip() + if input_kind not in {"text", "image"}: + raise ValueError( + "Unified scene input.input_kind must be 'text' or 'image', " + f"got {input_kind!r}." + ) + + return { + "unified_scene_result_path": result_path, + "unified_scene": unified_scene, + "input_kind": input_kind, + } + + +def generate_text_assets_node( + state: UnifiedSceneGenState, +) -> dict[str, object]: + """Generate images, RGBA cutouts, geometry, and sim-ready GLBs for a + text-origin unified scene. + """ + unified_scene = state["unified_scene"] + if unified_scene is None: + return {"generation_status": "no_unified_scene"} + + paths = UnifiedScenePaths(state["output_root"]) + output_root = paths.output_root + image_gen_dir, glb_gen_dir, debug_dir = paths.prepare_generation_dirs() + log_info( + "generate_text_assets started " + f"output_dir={output_root / UNIFIED_SCENE_GEN_STEP}" + ) + + table_spec = unified_scene.get("table") or {} + table_result = generate_text_table_asset( + table_spec=table_spec, + image_gen_dir=image_gen_dir, + glb_gen_dir=glb_gen_dir, + debug_dir=debug_dir, + ) + + object_specs = unified_scene.get("objects") or [] + object_results = generate_text_object_assets( + object_specs=object_specs, + image_gen_dir=image_gen_dir, + glb_gen_dir=glb_gen_dir, + debug_dir=debug_dir, + ) + metric_prompt_objects = [ + { + "object_id": str(obj.get("id", "")), + "object_name": str(obj.get("name", "")), + "object_description": str(obj.get("description", "")), + } + for obj in object_results + ] + user_text = str((unified_scene.get("input") or {}).get("text") or "") + text_metric_scale_result = estimate_text_scene_metric_scale( + object_results=object_results, + user_text=user_text, + messages=build_text_metric_scale_messages( + user_text=user_text, + objects_json=metric_prompt_objects, + ), + schema=IMAGE_METRIC_SCALE_JSON_SCHEMA, + output_dir=glb_gen_dir / "metric_scale", + output_root=output_root, + llm=state.get("llm"), + step_name=UNIFIED_SCENE_STEP, + ) + + result_path = paths.resolve_scene_result(state["unified_scene_result_path"]) + update_unified_scene(unified_scene, table_result, object_results, output_root) + write_json(result_path, unified_scene) + WorkflowArtifactWriter(output_root, UNIFIED_SCENE_GEN_STEP).write_step_result( + { + "table": table_result, + "objects": object_results, + "text_metric_scale": text_metric_scale_result, + "generation_status": "ok", + } + ) + log_info( + "generate_text_assets completed " + f"table_status={table_result.get('status')} " + f"object_count={len(object_results)}" + ) + + return { + "unified_scene": unified_scene, + "table_result": table_result, + "text_object_results": object_results, + "generation_status": "ok", + } + + +def generate_image_assets_node(state: UnifiedSceneGenState) -> dict[str, object]: + """Generate table assets and layout-aware object GLBs for image input. + + Table/support and objects are generated in one multi-object call from the + original image and existing segmentation masks. + """ + unified_scene = state["unified_scene"] + if unified_scene is None: + return {"generation_status": "no_unified_scene"} + + paths = UnifiedScenePaths(state["output_root"]) + output_root = paths.output_root + image_gen_dir, glb_gen_dir, debug_dir = paths.prepare_generation_dirs() + log_info( + "generate_image_assets started " + f"output_dir={output_root / UNIFIED_SCENE_GEN_STEP}" + ) + + segments_path = paths.image_segments_result + if not segments_path.is_file(): + raise FileNotFoundError( + f"Image segments result not found: {segments_path}" + ) + with segments_path.open("r", encoding="utf-8") as _f: + segments_data = json.load(_f) + if not isinstance(segments_data, dict): + raise ValueError("Image segments result must be a JSON object.") + + table_spec = unified_scene.get("table") or {} + # Image input uses the segmented table/support mask in the multi-object + # SAM3D call below. Text table generation belongs to the text branch. + object_specs = unified_scene.get("objects") or [] + object_layout_result = generate_image_scene_assets( + object_specs=object_specs, + table_spec=table_spec, + spatial_relations=(unified_scene.get("spatial") or {}).get("relations", []), + segments_data=segments_data, + image_gen_dir=image_gen_dir, + glb_gen_dir=glb_gen_dir, + debug_dir=debug_dir, + output_root=output_root, + llm=state.get("llm"), + ) + table_result = object_layout_result.get("table") or { + "id": str(table_spec.get("id", "table")), + "name": str(table_spec.get("name", "table")), + "status": "missing_table_generation", + } + object_results = object_layout_result.get("objects") or [] + generation_status = str(object_layout_result.get("status", "failed")) + if table_result.get("status") != "ok": + generation_status = str(table_result.get("status") or generation_status) + result_path = paths.resolve_scene_result(state["unified_scene_result_path"]) + update_unified_scene(unified_scene, table_result, object_results, output_root) + write_json(result_path, unified_scene) + WorkflowArtifactWriter(output_root, UNIFIED_SCENE_GEN_STEP).write_step_result( + { + "table": table_result, + "objects_layout": object_layout_result, + "objects": object_results, + "table_fit_to_clutter": None, + "generation_status": generation_status, + } + ) + log_info(f"generate_image_assets completed status={generation_status}") + + return { + "unified_scene": unified_scene, + "table_result": table_result, + "text_object_results": object_results, + "image_objects_layout_result": object_layout_result, + "generation_status": generation_status, + } + + +def fit_image_table_to_clutter_node(state: UnifiedSceneGenState) -> dict[str, object]: + """Resize the final table to fit the aligned image-object clutter.""" + if state.get("input_kind") != "image": + return {} + + paths = UnifiedScenePaths(state["output_root"]) + output_root = paths.output_root + output_dir = paths.table_fit_dir + output_dir.mkdir(parents=True, exist_ok=True) + log_info(f"fit_image_table_to_clutter started output_dir={output_dir}") + layout_result = dict(state.get("image_objects_layout_result") or {}) + table_fit_result = fit_image_scene_table( + layout_result=layout_result, + fallback_table_result=state.get("table_result"), + output_root=output_root, + output_dir=output_dir, + ) + layout_result["table_fit_to_clutter"] = table_fit_result + WorkflowArtifactWriter(output_root, UNIFIED_SCENE_GEN_STEP).write_step_result( + { + "table": state.get("table_result"), + "objects_layout": layout_result, + "objects": state.get("text_object_results") or [], + "table_fit_to_clutter": table_fit_result, + "generation_status": state.get("generation_status"), + } + ) + log_info( + f"fit_image_table_to_clutter completed status={table_fit_result.get('status')}" + ) + return { + "image_objects_layout_result": layout_result, + "table_fit_result": table_fit_result, + } + + +def generate_text_clutter_layout_node( + state: UnifiedSceneGenState, +) -> dict[str, object]: + """Scale text objects to real-world size, gravity-settle, centre at origin. + + Produces per-object settled GLBs and 2D AABB footprints for downstream + spatial layout optimisation and table fitting. + """ + if state.get("input_kind") != "text": + return {} + + paths = UnifiedScenePaths(state["output_root"]) + output_root = paths.output_root + output_dir = paths.text_clutter_dir + output_dir.mkdir(parents=True, exist_ok=True) + log_info(f"generate_text_clutter_layout started output_dir={output_dir}") + + text_object_results = state.get("text_object_results") or [] + if not text_object_results: + return { + "text_clutter_settle_result": { + "status": "skipped", + "reason": "no_text_objects", + } + } + + unified_scene = state.get("unified_scene") or {} + spatial_data = unified_scene.get("spatial") or {} + spatial_relations = spatial_data.get("relations", []) + table_constraints = spatial_data.get("table_constraints", []) + + settle_result = generate_text_clutter_layout( + object_results=text_object_results, + spatial_relations=spatial_relations, + table_constraints=table_constraints, + output_dir=output_dir, + output_root=output_root, + ) + WorkflowArtifactWriter(output_root, UNIFIED_SCENE_GEN_STEP).write_step_result( + { + "table": state.get("table_result"), + "objects": text_object_results, + "text_clutter_settle": settle_result, + "generation_status": state.get("generation_status"), + } + ) + log_info( + f"generate_text_clutter_layout completed status={settle_result.get('status')}" + ) + return { + "text_clutter_settle_result": settle_result, + } + + +def fit_text_table_to_clutter_node( + state: UnifiedSceneGenState, +) -> dict[str, object]: + """Resize the text-scene table to fit the laid-out clutter footprint.""" + if state.get("input_kind") != "text": + return {} + + paths = UnifiedScenePaths(state["output_root"]) + output_root = paths.output_root + table_result = state.get("table_result") + settle_result = state.get("text_clutter_settle_result") + + if table_result is None or settle_result is None: + return { + "table_fit_result": { + "status": "skipped", + "reason": "missing_table_or_settle_result", + } + } + + output_dir = paths.table_fit_dir + output_dir.mkdir(parents=True, exist_ok=True) + log_info(f"fit_text_table_to_clutter started output_dir={output_dir}") + table_fit_result = fit_text_scene_table( + table_result=table_result, + clutter_layout_result=settle_result, + output_root=output_root, + output_dir=output_dir, + ) + WorkflowArtifactWriter(output_root, UNIFIED_SCENE_GEN_STEP).write_step_result( + { + "table": table_result, + "objects": state.get("text_object_results") or [], + "text_clutter_settle": settle_result, + "table_fit_to_clutter": table_fit_result, + "generation_status": state.get("generation_status"), + } + ) + log_info( + f"fit_text_table_to_clutter completed status={table_fit_result.get('status')}" + ) + return { + "table_fit_result": table_fit_result, + } diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/paths.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/paths.py new file mode 100644 index 000000000..c4af80541 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/paths.py @@ -0,0 +1,102 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( + IMAGE_SEGMENTS_STEP, + STEP_RESULT_FILENAME, + UNIFIED_SCENE_GEN_STEP, + UNIFIED_SCENE_STEP, +) + +__all__ = ["UnifiedScenePaths", "resolve_generated_path"] + + +def resolve_generated_path(value: Any, output_root: Path) -> Path: + """Resolve an absolute or output-root-relative generated artifact path.""" + if not value: + return Path() + path = Path(str(value)).expanduser() + if path.is_absolute(): + return path.resolve() + return (output_root.expanduser().resolve() / path).resolve() + + +@dataclass(frozen=True) +class UnifiedScenePaths: + """High-level paths owned by the unified-scene generation workflow.""" + + output_root: Path + + def __post_init__(self) -> None: + object.__setattr__( + self, + "output_root", + self.output_root.expanduser().resolve(), + ) + + @property + def workflow_root(self) -> Path: + return self.output_root / UNIFIED_SCENE_GEN_STEP + + @property + def image_gen_dir(self) -> Path: + return self.workflow_root / "image_gen" + + @property + def glb_gen_dir(self) -> Path: + return self.workflow_root / "glb_gen" + + @property + def debug_dir(self) -> Path: + return self.workflow_root / "debug" + + @property + def text_clutter_dir(self) -> Path: + return self.glb_gen_dir / "text_clutter_settled" + + @property + def table_fit_dir(self) -> Path: + return self.glb_gen_dir / "table_fit_to_clutter" + + @property + def image_segments_result(self) -> Path: + return self.output_root / IMAGE_SEGMENTS_STEP / STEP_RESULT_FILENAME + + def prepare_generation_dirs(self) -> tuple[Path, Path, Path]: + """Create and return the workflow's high-level generation directories.""" + directories = (self.image_gen_dir, self.glb_gen_dir, self.debug_dir) + for directory in directories: + directory.mkdir(parents=True, exist_ok=True) + return directories + + def resolve_scene_result(self, explicit_path: Path | None) -> Path: + """Resolve the unified-scene result produced by the preceding workflow.""" + if explicit_path is not None: + return explicit_path.expanduser().resolve() + + scene_dir = self.output_root / UNIFIED_SCENE_STEP + result_path = scene_dir / STEP_RESULT_FILENAME + if result_path.is_file(): + return result_path + + legacy_path = scene_dir / "results.json" + return legacy_path if legacy_path.is_file() else result_path diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/prompts.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/prompts.py new file mode 100644 index 000000000..1543acfb6 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/prompts.py @@ -0,0 +1,141 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.prompts import render_prompt +from embodichain.gen_sim.prompt2scene.utils.io import image_to_data_url + +__all__ = [ + "build_image_metric_scale_messages", + "build_text_metric_scale_messages", + "build_up_down_flip_check_messages", +] + +UNIFIED_SCENE_GEN_PROMPT_NAME = "unified_scene_gen.yaml" + + +def build_image_metric_scale_messages( + *, + bbox_name_image_path: Path, + objects_json: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Build messages for image-scene object metric scale estimation.""" + return [ + { + "role": "system", + "content": render_prompt( + UNIFIED_SCENE_GEN_PROMPT_NAME, + prompt_key="image_metric_scale_system", + ), + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": render_prompt( + UNIFIED_SCENE_GEN_PROMPT_NAME, + { + "objects_json": json.dumps( + objects_json, + ensure_ascii=False, + indent=2, + ), + }, + prompt_key="image_metric_scale_user", + ), + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(bbox_name_image_path)}, + }, + ], + }, + ] + + +def build_text_metric_scale_messages( + *, + user_text: str, + objects_json: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Build messages for text-scene object metric scale estimation.""" + return [ + { + "role": "system", + "content": render_prompt( + UNIFIED_SCENE_GEN_PROMPT_NAME, + prompt_key="text_metric_scale_system", + ), + }, + { + "role": "user", + "content": render_prompt( + UNIFIED_SCENE_GEN_PROMPT_NAME, + { + "user_text": user_text, + "objects_json": json.dumps( + objects_json, + ensure_ascii=False, + indent=2, + ), + }, + prompt_key="text_metric_scale_user", + ), + }, + ] + + +def build_up_down_flip_check_messages( + *, + original_image_path: Path, + comparison_image_path: Path, +) -> list[dict[str, Any]]: + """Build messages for VLM support-normal up/down flip verification.""" + return [ + { + "role": "system", + "content": render_prompt( + UNIFIED_SCENE_GEN_PROMPT_NAME, + prompt_key="up_down_flip_check_system", + ), + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": render_prompt( + UNIFIED_SCENE_GEN_PROMPT_NAME, + prompt_key="up_down_flip_check_user", + ), + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(original_image_path)}, + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(comparison_image_path)}, + }, + ], + }, + ] diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/scene_update.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/scene_update.py new file mode 100644 index 000000000..2276e559d --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/scene_update.py @@ -0,0 +1,76 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.utils.io import relative_path + +__all__ = ["update_unified_scene"] + + +def update_unified_scene( + unified_scene: dict[str, Any], + table_result: dict[str, Any], + object_results: list[dict[str, Any]], + output_root: Path, +) -> None: + """Write generated asset references back into a unified-scene payload.""" + table = unified_scene.setdefault("table", {}) + metadata_keys = ( + "table_asset_source", + "support_normal_source", + "is_complete_visible_table", + "complete_table_description", + ) + path_keys = ( + "image_path", + "raw_geometry_path", + "support_reference_geometry_path", + "generated_table_raw_geometry_path", + "transformed_geometry_path", + "simready_geometry_path", + "aligned_geometry_path", + "mesh_path", + ) + for key in metadata_keys: + if key in table_result: + table[key] = table_result[key] + for key in path_keys: + if table_result.get(key): + table[key] = relative_path(table_result[key], output_root) + + objects_by_id = { + str(item.get("id", "")): item + for item in unified_scene.setdefault("objects", []) + if isinstance(item, dict) + } + for result in object_results: + target = objects_by_id.get(str(result.get("id", ""))) + if target is None: + continue + for key in ("image_path", "mesh_path", "aligned_geometry_path"): + if result.get(key): + target[key] = relative_path(result[key], output_root) + metric_scale = result.get("metric_scale") + if isinstance(metric_scale, dict): + target["metric_scale"] = { + key: value + for key, value in metric_scale.items() + if key not in {"result_path", "raw_model_output_path"} + } diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/schema.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/schema.py new file mode 100644 index 000000000..b22fcebba --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/schema.py @@ -0,0 +1,71 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +__all__ = [ + "IMAGE_METRIC_SCALE_JSON_SCHEMA", + "UP_DOWN_FLIP_CHECK_JSON_SCHEMA", +] + +UP_DOWN_FLIP_CHECK_JSON_SCHEMA: dict[str, Any] = { + "title": "AlignedUpDownFlipCheckOutput", + "type": "object", + "additionalProperties": False, + "properties": { + "selected_number": {"type": "integer", "enum": [1, 2]}, + "confidence": {"type": "number", "minimum": 0.0, "maximum": 1.0}, + "reason": {"type": "string"}, + }, + "required": ["selected_number", "confidence", "reason"], +} + +IMAGE_METRIC_SCALE_JSON_SCHEMA: dict[str, Any] = { + "title": "ImageMetricScaleEstimate", + "type": "object", + "additionalProperties": False, + "properties": { + "object_scales": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "object_id": {"type": "string"}, + "bbox_dims_cm": { + "type": "array", + "minItems": 3, + "maxItems": 3, + "items": { + "type": "number", + "minimum": 1.0e-6, + }, + }, + "confidence": { + "type": "number", + "minimum": 0.0, + "maximum": 1.0, + }, + "reason": {"type": "string"}, + }, + "required": ["object_id", "bbox_dims_cm", "confidence", "reason"], + }, + }, + }, + "required": ["object_scales"], +} diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/state.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/state.py new file mode 100644 index 000000000..122835160 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/state.py @@ -0,0 +1,40 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.workflows.attempt_state import AttemptState + +__all__ = ["UnifiedSceneGenState"] + + +class UnifiedSceneGenState(AttemptState): + """LangGraph state for downstream unified-scene generation.""" + + output_root: Path + unified_scene_result_path: Path | None + llm: Any | None + unified_scene: dict[str, Any] | None + input_kind: str | None + table_result: dict[str, Any] | None + text_object_results: list[dict[str, Any]] + text_clutter_settle_result: dict[str, Any] | None + image_objects_layout_result: dict[str, Any] | None + table_fit_result: dict[str, Any] | None + generation_status: str | None From 625802fab67d624ad621ab475878c6e9d3938bd4 Mon Sep 17 00:00:00 2001 From: Muzi Wong <178915912+MuziWong@users.noreply.github.com> Date: Mon, 29 Jun 2026 18:32:22 +0800 Subject: [PATCH 2/7] Fixed gym export bug: wrong object description; --- .../gen_sim/prompt2scene/agent_tools/tools/gym_export.py | 4 ++-- .../agent_tools/tools/image_scene_asset_generation.py | 1 + .../prompt2scene/agent_tools/tools/text_asset_generation.py | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py index 9f3c638f5..0dcd67180 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py @@ -183,8 +183,8 @@ def export_gym_config( oid = str(obj.get("id", "")) if oid: object_meta_by_id[oid] = { - "description": str(obj.get("description", "")).strip(), - "name": str(obj.get("name", "")).strip(), + "description": str(obj.get("description") or "").strip(), + "name": str(obj.get("name") or "").strip(), } table_info = step_result.get("table") or {} diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py index 2275c40fa..5df5984a4 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py @@ -470,6 +470,7 @@ def generate_image_scene_assets( object_fields = ( "id", "name", + "description", "status", "image_path", "mesh_path", diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py index 1beb76039..b0d4a0f72 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py @@ -162,6 +162,7 @@ def generate_text_object_asset( return { "id": object_id, "name": object_name, + "description": description, "status": status, "image_path": image_path, "raw_geometry_path": raw_geometry_path, From 68adf9a1e11236fdd7d4014eefddf56572eac98a Mon Sep 17 00:00:00 2001 From: Muzi Wong <178915912+MuziWong@users.noreply.github.com> Date: Mon, 29 Jun 2026 19:49:30 +0800 Subject: [PATCH 3/7] 1. Update prompt in scene_intake; 2. VLM judge percentage of the object clutter when the input tabletop seems to be a complete one; --- .../table_clutter_fit_manager/manager.py | 15 ++++++- .../tools/image_scene_asset_generation.py | 1 + .../agent_tools/tools/table_fit_scene.py | 2 + .../tools/text_asset_generation.py | 1 + .../prompts/data/scene_intake.yaml | 45 ++++++++++++++++--- .../workflows/scene_intake/prompts.py | 27 ++++++----- .../workflows/scene_intake/schema.py | 18 +++++++- .../workflows/scene_intake/utils.py | 29 +++++++++++- .../workflows/unified_scene/schema.py | 6 ++- .../workflows/unified_scene/utils.py | 7 ++- 10 files changed, 130 insertions(+), 21 deletions(-) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py index 987e14878..3a9a86e5b 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py @@ -94,10 +94,19 @@ def fit_table_to_clutter( output_dir: Path, margin_cm: float = 10.0, support_occupancy_ratio: float = 0.80, + object_coverage_percent: int | None = None, gravity_settle_table: bool = True, sim_device: str = "cpu", ) -> dict[str, Any]: - """Fit a table mesh to an already laid-out clutter result.""" + """Fit a table mesh to an already laid-out clutter result. + + Args: + object_coverage_percent: If set (1-100), overrides + ``support_occupancy_ratio`` by converting the percentage to a ratio + (e.g. 30 → 0.30). The required table size is computed as + clutter_size / ratio. When None, the default + ``support_occupancy_ratio`` is used. + """ try: import trimesh except ImportError as exc: @@ -166,6 +175,10 @@ def fit_table_to_clutter( # Compute the required table size and uniform scale. clutter_size_cm = (clutter_bounds[1, :2] - clutter_bounds[0, :2]) * 100.0 + if object_coverage_percent is not None: + support_occupancy_ratio = float( + np.clip(object_coverage_percent / 100.0, 0.1, 1.0) + ) occupancy = float(np.clip(support_occupancy_ratio, 0.1, 1.0)) required_size_cm = clutter_size_cm / occupancy + 2.0 * float(margin_cm) support_size_cm = np.asarray(initial_support["size_xy"], dtype=np.float64) * 100.0 diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py index 5df5984a4..9d3e42f1d 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py @@ -456,6 +456,7 @@ def generate_image_scene_assets( "status", "is_complete_visible_table", "complete_table_description", + "object_coverage_percent", "table_asset_source", "support_normal_source", "image_path", diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_fit_scene.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_fit_scene.py index ae96b3a39..273f15a65 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_fit_scene.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_fit_scene.py @@ -42,6 +42,7 @@ def fit_text_scene_table( clutter_result=clutter_layout_result, output_root=output_root, output_dir=output_dir, + object_coverage_percent=table_result.get("object_coverage_percent"), ) log_info(f"text table fit completed status={result.get('status')}") return result @@ -94,6 +95,7 @@ def fit_image_scene_table( clutter_result=clutter_result, output_root=output_root, output_dir=output_dir, + object_coverage_percent=generated_table.get("object_coverage_percent"), ) log_info(f"image table fit completed status={result.get('status')}") return result diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py index b0d4a0f72..ada7ad789 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_asset_generation.py @@ -283,6 +283,7 @@ def generate_text_table_asset( "is_complete_visible_table": bool( table_spec.get("is_complete_visible_table", False) ), + "object_coverage_percent": table_spec.get("object_coverage_percent"), "status": status, "image_path": image_path, "raw_geometry_path": raw_geometry_path, diff --git a/embodichain/gen_sim/prompt2scene/prompts/data/scene_intake.yaml b/embodichain/gen_sim/prompt2scene/prompts/data/scene_intake.yaml index cabf99cb5..bbdbbc8b0 100644 --- a/embodichain/gen_sim/prompt2scene/prompts/data/scene_intake.yaml +++ b/embodichain/gen_sim/prompt2scene/prompts/data/scene_intake.yaml @@ -19,6 +19,10 @@ text_system: | + - CRITICAL: Include EVERY visible object on the tabletop without omission. Do + not skip, ignore, or drop any object, no matter how small, blurry, partially + occluded, or unfamiliar it appears. An incomplete assets list is the most + severe error you can make. - Output only real physical objects that can become 3D asset generation targets. - Do not include the table or tabletop region in assets. - assets is a list of object category groups, not a list of individual object @@ -178,6 +182,10 @@ image_system: | + - CRITICAL: Include EVERY visible object on the tabletop without omission. Do + not skip, ignore, or drop any object, no matter how small, blurry, partially + occluded, or unfamiliar it appears. An incomplete assets list is the most + severe error you can make. - Output only real physical objects that can become 3D asset generation targets. - Do not include the table or tabletop region in assets. - assets is a list of object category groups, not a list of individual object @@ -293,6 +301,19 @@ image_system: | table.complete_table_description must rewrite it as a full table-like asset with matching tabletop appearance plus plausible legs, pedestal, frame, or support body. + - For image input with is_complete_visible_table=true ONLY: choose + table.object_coverage_percent from exactly one of these four values. + Think in terms of SPATIAL SPREAD, not pixel area: imagine drawing the + smallest rectangle that encloses ALL objects on the tabletop, then ask + what fraction of the table surface that rectangle covers. Even sparse + small objects can score high if they are spread across the whole table. + 10 (objects clustered in one small region, most of the table is bare), + 30 (objects spread across a noticeable portion but large bare areas remain), + 50 (objects reach roughly half the table extent in at least one direction), + 70 (objects span most of the table, even if gaps exist between them). + Do not output any other value. + - For text input, or when is_complete_visible_table=false: OMIT the + object_coverage_percent field entirely. Do not include it in the output. @@ -301,8 +322,9 @@ image_system: | "name": "table", "description": "A rectangular wooden table with a brown top and four straight legs.", "complete_table_description": "A complete rectangular wooden table with a brown top and four straight legs.", - "is_complete_visible_table": false, - "class_candidate": ["table", "dining_table", "desk", "wooden_table", "furniture"] + "is_complete_visible_table": true, + "class_candidate": ["table", "dining_table", "desk", "wooden_table", "furniture"], + "object_coverage_percent": 25 }, "assets": [ { @@ -374,6 +396,10 @@ verifier_system: | + - CRITICAL: Do NOT remove any asset row from the draft assets list. Your job is + to check and correct counts, names, and class_candidate values — not to drop + objects. If an object exists in the draft, it must remain in the corrected + output. Only add new rows if objects were clearly missed. - assets is a list of object category groups, not individual instances. - Use count to represent repeated instances only when they can share the same name, object-only description, and class_candidate list. @@ -419,6 +445,13 @@ verifier_system: | use the most conservative count supported by the image. - For text inputs, count only objects explicitly stated or strongly implied by the text. + - For image input with is_complete_visible_table=true: independently + re-assess the tabletop coverage against the original image and pick + table.object_coverage_percent from exactly one of 10, 30, 50, 70. + Correct the draft value if the bucket does not match the visible + clutter density. + - For text input or when is_complete_visible_table is false: remove + object_coverage_percent from table entirely if it is present in the draft. @@ -427,8 +460,9 @@ verifier_system: | "name": "table", "description": "A rectangular wooden table with a brown top and four straight legs.", "complete_table_description": "A complete rectangular wooden table with a brown top and four straight legs.", - "is_complete_visible_table": false, - "class_candidate": ["table", "dining_table", "desk", "wooden_table", "furniture"] + "is_complete_visible_table": true, + "class_candidate": ["table", "dining_table", "desk", "wooden_table", "furniture"], + "object_coverage_percent": 30 }, "assets": [ { @@ -444,7 +478,8 @@ verifier_system: | - The top-level object must contain only table and assets. - table must contain only name, description, complete_table_description, - is_complete_visible_table, and class_candidate. + is_complete_visible_table, class_candidate, and optionally + object_coverage_percent (only when is_complete_visible_table is true). - Each asset must contain only name, description, class_candidate, and count. - Output JSON only. Do not include markdown or explanations outside JSON. diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/prompts.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/prompts.py index 611c5bf95..421ec979b 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/prompts.py +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/prompts.py @@ -50,19 +50,24 @@ def build_scene_intake_verifier_messages( scene_intake: SceneIntakeSpec, ) -> list[dict[str, Any]]: """Build messages for scene-intake group and count verification.""" + table_draft: dict[str, object] = { + "name": scene_intake.table.name, + "description": scene_intake.table.description, + "complete_table_description": ( + scene_intake.table.complete_table_description + ), + "is_complete_visible_table": ( + scene_intake.table.is_complete_visible_table + ), + "class_candidate": list(scene_intake.table.class_candidate), + } + if scene_intake.table.object_coverage_percent is not None: + table_draft["object_coverage_percent"] = ( + scene_intake.table.object_coverage_percent + ) scene_intake_json = json.dumps( { - "table": { - "name": scene_intake.table.name, - "description": scene_intake.table.description, - "complete_table_description": ( - scene_intake.table.complete_table_description - ), - "is_complete_visible_table": ( - scene_intake.table.is_complete_visible_table - ), - "class_candidate": list(scene_intake.table.class_candidate), - }, + "table": table_draft, "assets": [ { "name": asset.name, diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/schema.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/schema.py index 80c9ca27c..31b55e6df 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/schema.py +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/schema.py @@ -92,6 +92,18 @@ "minLength": 1, }, }, + "object_coverage_percent": { + "type": "integer", + "enum": [10, 30, 50, 70], + "description": ( + "For image input with a complete visible table ONLY: " + "choose the closest coverage bucket for objects on the " + "tabletop: 10 (mostly empty, a few small objects), " + "30 (lightly cluttered), 50 (moderately cluttered), " + "70 (densely packed). Omit this field entirely for " + "text input or when is_complete_visible_table is false." + ), + }, }, "required": [ "name", @@ -193,10 +205,11 @@ class SceneIntakeTable: complete_table_description: str = "" is_complete_visible_table: bool = False class_candidate: list[str] = field(default_factory=list) + object_coverage_percent: int | None = None def to_manifest(self) -> dict[str, object]: """Convert the table record to JSON-safe data.""" - return { + manifest: dict[str, object] = { "id": self.id, "name": self.name, "description": self.description, @@ -204,6 +217,9 @@ def to_manifest(self) -> dict[str, object]: "is_complete_visible_table": self.is_complete_visible_table, "class_candidate": list(self.class_candidate), } + if self.object_coverage_percent is not None: + manifest["object_coverage_percent"] = self.object_coverage_percent + return manifest @dataclass(frozen=True) diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/utils.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/utils.py index e49fe9b3d..da084f559 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/utils.py +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/utils.py @@ -19,6 +19,7 @@ import re from typing import Any +from embodichain.gen_sim.prompt2scene.utils.log import log_warning from embodichain.gen_sim.prompt2scene.workflows.request import Prompt2SceneInput from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( SceneIntakeAsset, @@ -66,6 +67,7 @@ def _parse_table(raw_table: dict[str, Any]) -> SceneIntakeTable: "complete_table_description", "is_complete_visible_table", "class_candidate", + "object_coverage_percent", }, context="Scene intake table", ) @@ -107,12 +109,34 @@ def _parse_table(raw_table: dict[str, Any]) -> SceneIntakeTable: raw_name=name, ) + object_coverage_percent: int | None = None + raw_percent = raw_table.get("object_coverage_percent") + if raw_percent is not None: + if isinstance(raw_percent, bool): + raise ValueError( + "Scene intake table.object_coverage_percent must be an integer, " + "not a boolean." + ) + try: + object_coverage_percent = int(raw_percent) + except (TypeError, ValueError): + raise ValueError( + "Scene intake table.object_coverage_percent must be an integer " + f"between 1 and 100, got {raw_percent!r}." + ) + if object_coverage_percent not in (10, 30, 50, 70): + raise ValueError( + "Scene intake table.object_coverage_percent must be one of " + f"10, 30, 50, 70, got {object_coverage_percent}." + ) + return SceneIntakeTable( name=name, description=description, complete_table_description=complete_table_description, is_complete_visible_table=is_complete_visible_table, class_candidate=class_candidate, + object_coverage_percent=object_coverage_percent, ) @@ -214,7 +238,10 @@ def _validate_exact_keys( ) -> None: extra_keys = sorted(set(value) - allowed_keys) if extra_keys: - raise ValueError(f"{context} has unexpected keys: {extra_keys}.") + log_warning( + f"{context} has unexpected keys: {extra_keys}. " + f"These fields will be ignored." + ) def _require_mapping(value: Any, context: str) -> dict[str, Any]: diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene/schema.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/schema.py index f3d13125d..baca2bebe 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/unified_scene/schema.py +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/schema.py @@ -42,10 +42,11 @@ class UnifiedTable: image_path: str | None = None mesh_path: str | None = None grid_cells: dict[str, list[str]] | None = None + object_coverage_percent: int | None = None def to_manifest(self) -> dict[str, Any]: """Convert the table to JSON-safe data.""" - return { + manifest: dict[str, Any] = { "id": self.id, "name": self.name, "description": self.description, @@ -56,6 +57,9 @@ def to_manifest(self) -> dict[str, Any]: "mesh_path": self.mesh_path, "grid_cells": self.grid_cells, } + if self.object_coverage_percent is not None: + manifest["object_coverage_percent"] = self.object_coverage_percent + return manifest @dataclass(frozen=True) diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene/utils.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/utils.py index e17b5e7b6..49e4a70cb 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/unified_scene/utils.py +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/utils.py @@ -93,7 +93,7 @@ def build_unified_table( grid_cells: dict[str, list[str]] | None = None, ) -> dict[str, Any]: """Build a unified table record from scene intake.""" - return { + table: dict[str, Any] = { "id": scene_intake.table.id, "name": scene_intake.table.name, "description": scene_intake.table.description, @@ -106,6 +106,11 @@ def build_unified_table( "mesh_path": None, "grid_cells": grid_cells, } + if scene_intake.table.object_coverage_percent is not None: + table["object_coverage_percent"] = ( + scene_intake.table.object_coverage_percent + ) + return table def build_unified_spatial_anchor(anchor: ImageAnchor | None) -> dict[str, Any] | None: From 8293f2cff7c17420083ea5afa3adc8064934e38e Mon Sep 17 00:00:00 2001 From: Muzi Wong <178915912+MuziWong@users.noreply.github.com> Date: Tue, 30 Jun 2026 10:57:15 +0800 Subject: [PATCH 4/7] Fixed gym export bug; --- .../table_clutter_fit_manager/manager.py | 18 +- .../agent_tools/tools/gym_export.py | 315 +++++++++++++----- 2 files changed, 240 insertions(+), 93 deletions(-) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py index 3a9a86e5b..eeb79a182 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py @@ -252,7 +252,23 @@ def fit_table_to_clutter( for oid, scene in shifted_clutter: object_path = output_dir / f"{oid}_on_table.glb" _copy_scene_with_transform(scene, z_to_y).export(object_path) - placed_objects.append({"id": oid, "path": str(object_path)}) + # Compute world-space AABB bottom-centre (sim Z-up coords) before + # the scene is converted to GLB Y-up for export. This is the + # reference position that gym_export uses to derive ``init_pos``. + _placed_mesh = _scene_to_mesh(scene, trimesh=trimesh) + _placed_b = np.asarray(_placed_mesh.bounds, dtype=np.float64) + world_aabb_bottom_center = [ + float(0.5 * (_placed_b[0, 0] + _placed_b[1, 0])), + float(0.5 * (_placed_b[0, 1] + _placed_b[1, 1])), + float(_placed_b[0, 2]), + ] + placed_objects.append( + { + "id": oid, + "path": str(object_path), + "world_aabb_bottom_center": world_aabb_bottom_center, + } + ) # Write the fit manifest. final_clutter_bounds = _table_fit_scene_union_bounds( diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py index 0dcd67180..d26a14842 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py @@ -20,6 +20,7 @@ import math import shutil import time +from collections.abc import Sequence from pathlib import Path from typing import Any @@ -49,7 +50,12 @@ "restitution": 0.01, } -_DEFAULT_MAX_CONVEX_HULL_NUM = 8 +_DEFAULT_MAX_CONVEX_HULL_NUM = 32 + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- def _resolve_path(value: str, output_root: Path) -> Path: @@ -83,13 +89,37 @@ def _matrix_to_euler_xyz_deg(matrix: list[list[float]]) -> list[float]: return [math.degrees(x), math.degrees(y), math.degrees(z)] -def _glb_aabb_bottom_center(glb_path: Path) -> list[float]: - """``[x, y, z]`` bottom-centre position in **simulation Z-up** space. +def _glb_to_sim_rotation() -> np.ndarray: + """Return the loader basis conversion from GLB Y-up to sim Z-up.""" + return np.array( + [ + [1.0, 0.0, 0.0], + [0.0, 0.0, -1.0], + [0.0, 1.0, 0.0], + ], + dtype=np.float64, + ) - The GLB is stored in Y-up convention (X=width, Y=up, Z=depth). - EmbodiChain simulation converts to Z-up on load, so we return the - position in Z-up space: ``center_X``, ``-center_Z``, ``min_Y``. - """ + +def _glb_rotation_to_sim(rotation_matrix: list[list[float]]) -> list[list[float]]: + """Convert a GLB-space local rotation into simulation-space rotation.""" + rot = np.asarray(rotation_matrix, dtype=np.float64) + if rot.shape == (4, 4): + rot = rot[:3, :3] + basis = _glb_to_sim_rotation() + return (basis @ rot @ basis.T).tolist() + + +def _glb_scale_to_sim(scale: Sequence[float]) -> list[float]: + """Convert GLB-axis scale components to sim-axis body_scale components.""" + values = [float(v) for v in scale] + if len(values) != 3: + raise ValueError("scale must have three components") + return [values[0], values[2], values[1]] + + +def _glb_max_z(glb_path: Path) -> float: + """Maximum height (Y in GLB, Z in simulation) of a mesh.""" import trimesh scene = trimesh.load(glb_path, force="scene") @@ -104,16 +134,23 @@ def _glb_aabb_bottom_center(glb_path: Path) -> list[float]: [m for m in dumped if isinstance(m, trimesh.Trimesh)] ) ) - b = np.asarray(mesh.bounds, dtype=np.float64) - return [ - float(0.5 * (b[0, 0] + b[1, 0])), # centre X - float(-0.5 * (b[0, 2] + b[1, 2])), # -centre Z (GLB Z → internal -Y) - float(b[0, 1]), # min Y (GLB up → internal Z) - ] + return float(np.asarray(mesh.bounds, dtype=np.float64)[1, 1]) # max Y -def _glb_max_z(glb_path: Path) -> float: - """Maximum height (Y in GLB, Z in simulation) of a mesh.""" +def _rotated_aabb_offsets( + glb_path: Path, + rotation_matrix: list[list[float]] | None, + scale: float | Sequence[float] = 1.0, +) -> tuple[float, float, float]: + """Compute the AABB shift caused by rotation + scale alone. + + Loads the simready GLB, applies *rotation_matrix* and *scale_factor* + around the local origin (the AABB bottom-centre), and returns the XY + centre and minimum Z of the resulting AABB. These offsets are + subtracted from the fitted AABB bottom-centre to recover the true + world-space position of the simready local origin (the ``init_pos`` + that the simulation expects). + """ import trimesh scene = trimesh.load(glb_path, force="scene") @@ -128,7 +165,122 @@ def _glb_max_z(glb_path: Path) -> float: [m for m in dumped if isinstance(m, trimesh.Trimesh)] ) ) - return float(np.asarray(mesh.bounds, dtype=np.float64)[1, 1]) # max Y + verts = mesh.vertices.copy() + if isinstance(scale, Sequence) and not isinstance(scale, (str, bytes)): + scale_array = np.asarray(list(scale), dtype=np.float64) + if scale_array.shape != (3,): + raise ValueError("scale must be a scalar or a 3-vector") + verts *= scale_array + else: + verts *= float(scale) + if rotation_matrix is not None: + rot = np.asarray(rotation_matrix, dtype=np.float64) + if rot.shape == (4, 4): + rot = rot[:3, :3] + verts = (rot @ verts.T).T + b = np.zeros((2, 3), dtype=np.float64) + b[0] = verts.min(axis=0) + b[1] = verts.max(axis=0) + return ( + float(0.5 * (b[0, 0] + b[1, 0])), # AABB centre X → sim X + float(-0.5 * (b[0, 2] + b[1, 2])), # -centre Z → sim Y + float(b[0, 1]), # min Y → sim Z + ) + + +# --------------------------------------------------------------------------- +# consolidated object manifest +# --------------------------------------------------------------------------- + + +def _build_object_manifest( + output_root: Path, + step_result: dict[str, Any], + table_fit_manifest: dict[str, Any], + aligned_by_id: dict[str, dict[str, Any]], +) -> dict[str, Any]: + """Merge world_bc, rotation, scale into one per-object record. + + Returns a dict keyed by object id, each value containing everything + needed to compute ``init_pos`` / ``init_rot`` / ``body_scale``. + """ + objects_info = step_result.get("objects") or [] + + # index metric_scale by object id + metric_by_id: dict[str, float] = {} + for obj in objects_info: + oid = str(obj.get("id", "")) + if not oid: + continue + ms = obj.get("metric_scale") + sf = float(ms.get("scale_factor", 1.0)) if isinstance(ms, dict) else 1.0 + metric_by_id[oid] = sf + + # index world_aabb_bottom_center from table-fit manifest + world_bc_by_id: dict[str, list[float]] = {} + for e in table_fit_manifest.get("objects") or []: + eid = str(e.get("id", "")) if isinstance(e, dict) else "" + wbc = e.get("world_aabb_bottom_center") if isinstance(e, dict) else None + if eid and isinstance(wbc, list) and len(wbc) == 3: + world_bc_by_id[eid] = [float(v) for v in wbc] + + consolidated: dict[str, Any] = {} + skipped_no_glb: list[str] = [] + for obj in objects_info: + oid = str(obj.get("id", "")) + if not oid: + continue + + source = obj.get("simready_geometry_path") or obj.get("mesh_path") + simready_path = _resolve_path(source or "", output_root) + if not simready_path.is_file(): + skipped_no_glb.append(oid) + continue + + description = str(obj.get("description") or obj.get("name") or "").strip() + scale_factor = metric_by_id.get(oid, 1.0) + + aligned = aligned_by_id.get(oid) + rot_matrix: list[list[float]] | None = None + transform_scale: list[float] | None = None + if aligned: + raw = aligned.get("rotation_matrix") + if raw and isinstance(raw, list): + rot_matrix = raw + raw_scale = aligned.get("scale") + if isinstance(raw_scale, list) and len(raw_scale) == 3: + transform_scale = [float(v) for v in raw_scale] + + wbc = world_bc_by_id.get(oid) + + consolidated[oid] = { + "id": oid, + "description": description, + "simready_path": simready_path, + "scale_factor": scale_factor, + "transform_scale": transform_scale, + "rotation_matrix": rot_matrix, + "world_aabb_bottom_center": wbc, + } + + if skipped_no_glb: + print( + " [WARN] object(s) skipped (simready GLB not found): " + + ", ".join(skipped_no_glb) + ) + extra_in_manifest = set(world_bc_by_id) - set(consolidated) + if extra_in_manifest: + print( + " [WARN] object(s) in table-fit manifest but not in step_result: " + + ", ".join(sorted(extra_in_manifest)) + ) + + return consolidated + + +# --------------------------------------------------------------------------- +# main export +# --------------------------------------------------------------------------- def export_gym_config( @@ -148,45 +300,33 @@ def export_gym_config( export_dir = export_dir.expanduser().resolve() export_dir.mkdir(parents=True, exist_ok=True) - # ── step result & table-fit manifest ────────────────────────────── + # ── data sources ──────────────────────────────────────────────────── step_result = _read_json( output_root / UNIFIED_SCENE_GEN_STEP / STEP_RESULT_FILENAME ) table_fit = step_result.get("table_fit_to_clutter") or {} - manifest = _read_json( + table_fit_manifest = _read_json( _resolve_path(table_fit.get("manifest_path", ""), output_root) ) - # ── per-object metadata from simready→aligned manifest ──────────── aligned_by_id: dict[str, dict[str, Any]] = {} aligned_manifest_path = ( - output_root / UNIFIED_SCENE_GEN_STEP / "glb_gen" / "simready_to_aligned_manifest.json" + output_root + / UNIFIED_SCENE_GEN_STEP + / "glb_gen" + / "simready_to_aligned_manifest.json" ) if aligned_manifest_path.is_file(): - aligned_manifest = _read_json(aligned_manifest_path) - for item in aligned_manifest.get("items", []) or []: - if isinstance(item, dict): - aligned_by_id[str(item.get("id", ""))] = item - - # ── table surface Z (from fitted table GLB) ─────────────────────── - fitted_table_path = _resolve_path( - manifest.get("table_output_path", ""), output_root - ) - table_surface_z = ( - _glb_max_z(fitted_table_path) if fitted_table_path.is_file() else 0.0 - ) + for item in _read_json(aligned_manifest_path).get("items", []) or []: + if isinstance(item, dict) and item.get("id"): + aligned_by_id[str(item["id"])] = item - # ── description lookup ──────────────────────────────────────────── - object_meta_by_id: dict[str, dict[str, str]] = {} - for obj in step_result.get("objects", []) or []: - if isinstance(obj, dict): - oid = str(obj.get("id", "")) - if oid: - object_meta_by_id[oid] = { - "description": str(obj.get("description") or "").strip(), - "name": str(obj.get("name") or "").strip(), - } + # ── consolidated per-object manifest ───────────────────────────────── + object_manifest = _build_object_manifest( + output_root, step_result, table_fit_manifest, aligned_by_id + ) + # ── table ──────────────────────────────────────────────────────────── table_info = step_result.get("table") or {} table_desc = str( table_info.get("complete_table_description") @@ -196,7 +336,6 @@ def export_gym_config( mesh_assets_dir = export_dir / "mesh_assets" mesh_assets_dir.mkdir(parents=True, exist_ok=True) - # ── table ───────────────────────────────────────────────────────── table_simready = _resolve_path( table_info.get("simready_geometry_path") or table_info.get("mesh_path", ""), @@ -208,67 +347,52 @@ def export_gym_config( table_dst.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(table_simready, table_dst) + table_surface_z = _glb_max_z(table_simready) + uniform_scale = 1.0 - ts = manifest.get("table_xy_scale") + ts = table_fit_manifest.get("table_xy_scale") if isinstance(ts, dict): uniform_scale = float(ts.get("uniform_scale", 1.0)) - # ── objects ─────────────────────────────────────────────────────── - table_fit_objects = { - str(e["id"]): _resolve_path(e["path"], output_root) - for e in (manifest.get("objects") or []) - if isinstance(e, dict) - } - objects_info = step_result.get("objects") or [] + # ── objects ────────────────────────────────────────────────────────── rigid_objects: list[dict[str, Any]] = [] - def _obj_desc(obj: dict[str, Any]) -> str: - meta = object_meta_by_id.get(str(obj.get("id", ""))) - return (meta["description"] or meta["name"]) if meta else "" - - for obj in objects_info: - if not isinstance(obj, dict): - continue - object_id = str(obj.get("id", "")) - if not object_id: - continue - - # ── GLB: simready (normalised, no baked transforms) ────────── - source = obj.get("simready_geometry_path") or obj.get("mesh_path") - object_src = _resolve_path(source, output_root) - if not object_src.is_file(): - continue - - safe_name = object_id.replace("interact_", "").strip("_") or "object" - obj_dir = mesh_assets_dir / safe_name / object_id + total = len(object_manifest) + for idx, (oid, om) in enumerate(object_manifest.items()): + # Copy simready GLB + safe_name = oid.replace("interact_", "").strip("_") or "object" + obj_dir = mesh_assets_dir / safe_name / oid obj_dir.mkdir(parents=True, exist_ok=True) - object_dst = obj_dir / f"{object_id}.glb" - shutil.copy2(object_src, object_dst) + object_dst = obj_dir / f"{oid}.glb" + shutil.copy2(om["simready_path"], object_dst) - # ── body_scale ──────────────────────────────────────────────── - ms = obj.get("metric_scale") - scale_factor = float(ms.get("scale_factor", 1.0)) if isinstance(ms, dict) else 1.0 - body_scale = [scale_factor, scale_factor, scale_factor] + # body_scale. Image-scene alignment may contain a full simready→aligned + # scale; text-scene layout only has the per-object metric scale. + sf = om["scale_factor"] + scale_glb = om.get("transform_scale") or [sf, sf, sf] + body_scale = _glb_scale_to_sim(scale_glb) - # ── init_pos: read from fitted on-table GLB ─────────────────── - fitted_path = table_fit_objects.get(object_id) - if fitted_path and fitted_path.is_file(): - init_pos = _glb_aabb_bottom_center(fitted_path) - else: - init_pos = [0.0, 0.0, table_surface_z] - - # ── init_rot: decompose from simready→aligned rotation ──────── + # init_rot init_rot: list[float] = [0.0, 0.0, 0.0] - aligned = aligned_by_id.get(object_id) - if aligned: - rot = aligned.get("rotation_matrix") - if rot and isinstance(rot, list): - init_rot = _matrix_to_euler_xyz_deg(rot) + if om["rotation_matrix"] is not None: + init_rot = _matrix_to_euler_xyz_deg( + _glb_rotation_to_sim(om["rotation_matrix"]) + ) + + # init_pos = world_bc - rotated_aabb_offset + ro = _rotated_aabb_offsets( + om["simready_path"], om["rotation_matrix"], scale_glb + ) + wbc = om["world_aabb_bottom_center"] + if wbc is not None: + init_pos = [wbc[0] - ro[0], wbc[1] - ro[1], wbc[2] - ro[2]] + else: + init_pos = [-ro[0], -ro[1], table_surface_z - ro[2]] rigid_objects.append( { - "uid": object_id, - "description": _obj_desc(obj), + "uid": oid, + "description": om["description"], "shape": { "shape_type": "Mesh", "fpath": str(object_dst.relative_to(export_dir)), @@ -282,8 +406,14 @@ def _obj_desc(obj: dict[str, Any]) -> str: "max_convex_hull_num": _DEFAULT_MAX_CONVEX_HULL_NUM, } ) + wbc = om["world_aabb_bottom_center"] + wbc_flag = "wbc" if wbc is not None else "fallback" + print( + f" [{idx+1}/{total}] [{oid}] {om['description']}" + f" pos={init_pos} rot={init_rot} scale={body_scale} src={wbc_flag}" + ) - # ── write config ────────────────────────────────────────────────── + # ── write gym config ───────────────────────────────────────────────── config = { "id": f"Prompt2Scene-{int(time.time() * 1000)}-v0", "max_episodes": 10, @@ -316,4 +446,5 @@ def _obj_desc(obj: dict[str, Any]) -> str: json.dumps(config, indent=4, ensure_ascii=False) + "\n", encoding="utf-8", ) + return config_path From aafb63bf99839b44a76141e91e9a4c99a84983f7 Mon Sep 17 00:00:00 2001 From: Muzi Wong <178915912+MuziWong@users.noreply.github.com> Date: Wed, 1 Jul 2026 13:33:37 +0800 Subject: [PATCH 5/7] Cleaned the code; --- .../managers/geometry_manager/manager.py | 181 ++- .../{scene_geometry.py => utils.py} | 0 .../managers/image_scene_manager/prompts.py | 106 -- .../managers/image_scene_manager/schemas.py | 71 - .../__init__.py | 6 +- .../managers/layout_manager/manager.py | 76 + .../__init__.py => layout_manager/schemas.py} | 12 +- .../managers/layout_manager/utils.py | 1350 +++++++++++++++++ .../managers/metric_scale_manager/__init__.py | 37 - .../managers/metric_scale_manager/manager.py | 431 ------ .../managers/metric_scale_manager/schemas.py | 73 - .../managers/optimization_manager/__init__.py | 37 - .../managers/optimization_manager/manager.py | 633 -------- .../managers/simready_manager/__init__.py | 21 + .../managers/simready_manager/manager.py | 444 ++++-- .../managers/simready_manager/schemas.py | 57 +- .../managers/simready_manager/utils.py | 136 ++ .../managers/text_layout_manager/__init__.py | 33 - .../managers/text_layout_manager/layout.py | 383 ----- .../text_layout_manager/optimization.py | 404 ----- .../image_layout_alignment.py} | 74 +- .../tools/image_scene_asset_generation.py | 20 +- .../agent_tools/tools/image_segment_filter.py | 189 +++ .../layout_manifests.py} | 12 +- .../tools/spatial_relations.py} | 0 .../manager.py => tools/table_clutter_fit.py} | 104 +- .../agent_tools/tools/table_fit_scene.py | 57 +- .../agent_tools/tools/text_clutter_layout.py | 2 +- .../settle.py => tools/text_object_settle.py} | 93 +- .../tools/text_scene_metric_scale.py | 2 +- .../{workflows => llms}/llm_output.py | 22 - .../gen_sim/prompt2scene/pipeline/runner.py | 27 +- .../gen_sim/prompt2scene/prompts/__init__.py | 7 +- .../gen_sim/prompt2scene/prompts/builders.py | 394 +++++ .../prompts/data/text_relations.yaml | 31 +- .../gen_sim/prompt2scene/prompts/schemas.py | 354 +++++ .../prompt2scene/workflows/artifact_writer.py | 114 +- .../tools => workflows}/gym_export.py | 172 +-- .../workflows/image_relations/nodes.py | 126 +- .../workflows/image_relations/prompts.py | 113 -- .../workflows/image_relations/schema.py | 81 - .../workflows/image_relations/utils.py | 173 +-- .../gen_sim/prompt2scene/workflows/paths.py | 219 +++ .../workflows/scene_intake/nodes.py | 24 +- .../workflows/scene_intake/prompts.py | 202 --- .../workflows/scene_intake/schema.py | 137 -- .../workflows/text_relations/nodes.py | 16 +- .../workflows/text_relations/prompts.py | 55 - .../workflows/text_relations/schema.py | 61 - .../workflows/text_relations/utils.py | 2 +- .../workflows/unified_scene/utils.py | 2 +- .../workflows/unified_scene_gen/nodes.py | 26 +- .../workflows/unified_scene_gen/paths.py | 102 -- .../workflows/unified_scene_gen/prompts.py | 141 -- .../workflows/unified_scene_gen/schema.py | 71 - .../{scene_update.py => utils.py} | 0 56 files changed, 3816 insertions(+), 3900 deletions(-) rename embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/{scene_geometry.py => utils.py} (100%) delete mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/prompts.py delete mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/schemas.py rename embodichain/gen_sim/prompt2scene/agent_tools/managers/{table_clutter_fit_manager => layout_manager}/__init__.py (83%) create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/layout_manager/manager.py rename embodichain/gen_sim/prompt2scene/agent_tools/managers/{image_scene_manager/__init__.py => layout_manager/schemas.py} (67%) create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/layout_manager/utils.py delete mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/__init__.py delete mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/manager.py delete mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/schemas.py delete mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/__init__.py delete mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/manager.py create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/utils.py delete mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/__init__.py delete mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/layout.py delete mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/optimization.py rename embodichain/gen_sim/prompt2scene/agent_tools/{managers/image_scene_manager/alignment.py => tools/image_layout_alignment.py} (89%) create mode 100644 embodichain/gen_sim/prompt2scene/agent_tools/tools/image_segment_filter.py rename embodichain/gen_sim/prompt2scene/agent_tools/{managers/image_scene_manager/manifests.py => tools/layout_manifests.py} (95%) rename embodichain/gen_sim/prompt2scene/{workflows/spatial.py => agent_tools/tools/spatial_relations.py} (100%) rename embodichain/gen_sim/prompt2scene/agent_tools/{managers/table_clutter_fit_manager/manager.py => tools/table_clutter_fit.py} (80%) rename embodichain/gen_sim/prompt2scene/agent_tools/{managers/text_layout_manager/settle.py => tools/text_object_settle.py} (84%) rename embodichain/gen_sim/prompt2scene/{workflows => llms}/llm_output.py (92%) create mode 100644 embodichain/gen_sim/prompt2scene/prompts/builders.py create mode 100644 embodichain/gen_sim/prompt2scene/prompts/schemas.py rename embodichain/gen_sim/prompt2scene/{agent_tools/tools => workflows}/gym_export.py (74%) delete mode 100644 embodichain/gen_sim/prompt2scene/workflows/image_relations/prompts.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/paths.py delete mode 100644 embodichain/gen_sim/prompt2scene/workflows/scene_intake/prompts.py delete mode 100644 embodichain/gen_sim/prompt2scene/workflows/text_relations/prompts.py delete mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/paths.py delete mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/prompts.py delete mode 100644 embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/schema.py rename embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/{scene_update.py => utils.py} (100%) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/manager.py index 2e5c88ab3..fa42ead4a 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/manager.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/manager.py @@ -60,6 +60,172 @@ class GeometryManager: the same pattern as service clients. """ + @staticmethod + def compose_json_matrices(*values: Any) -> list[list[float]]: + from . import utils as geometry_utils + + return geometry_utils._compose_json_matrices(*values) + + @staticmethod + def compose_simready_to_aligned_matrix( + *, raw_to_aligned_matrix: Any, raw_to_simready_matrix: Any + ) -> list[list[float]]: + from . import utils as geometry_utils + + return geometry_utils._compose_simready_to_aligned_matrix( + raw_to_aligned_matrix=raw_to_aligned_matrix, + raw_to_simready_matrix=raw_to_simready_matrix, + ) + + @staticmethod + def decompose_transform_matrix(matrix_value: Any) -> dict[str, Any]: + from . import utils as geometry_utils + + return geometry_utils._decompose_transform_matrix(matrix_value) + + @staticmethod + def support_normal_flip_transform(**kwargs: Any) -> Any: + from . import utils as geometry_utils + + return geometry_utils._support_normal_flip_transform(**kwargs) + + @staticmethod + def z_yaw_transform(yaw_degrees: float) -> Any: + from . import utils as geometry_utils + + return geometry_utils._z_yaw_transform(yaw_degrees) + + @staticmethod + def z_up_to_glb_y_up_transform() -> Any: + from . import utils as geometry_utils + + return geometry_utils._z_up_to_glb_y_up_transform() + + @staticmethod + def copy_scene_with_transform(scene: Any, transform: Any) -> Any: + from . import utils as geometry_utils + + return geometry_utils._copy_scene_with_transform(scene, transform) + + @staticmethod + def matrix_from_json(value: Any, *, name: str) -> Any: + from . import utils as geometry_utils + + return geometry_utils._matrix_from_json(value, name=name) + + @staticmethod + def load_scene_with_transform(**kwargs: Any) -> Any: + from . import utils as geometry_utils + + return geometry_utils._load_scene_with_transform(**kwargs) + + @staticmethod + def estimate_support_normal(mesh: Any) -> Any: + from . import utils as geometry_utils + + return geometry_utils._estimate_support_normal(mesh) + + @staticmethod + def rotation_between_vectors(source: Any, target: Any) -> Any: + from . import utils as geometry_utils + + return geometry_utils._rotation_between_vectors(source, target) + + @staticmethod + def transform_point(transform: Any, point: Any) -> Any: + from . import utils as geometry_utils + + return geometry_utils._transform_point(transform, point) + + @staticmethod + def aabb_center(bounds: Any) -> Any: + from . import utils as geometry_utils + + return geometry_utils._aabb_center(bounds) + + @staticmethod + def xy_aabb_center(bounds: Any) -> Any: + from . import utils as geometry_utils + + return geometry_utils._xy_aabb_center(bounds) + + @staticmethod + def xy_aabb_size(bounds: Any) -> Any: + from . import utils as geometry_utils + + return geometry_utils._xy_aabb_size(bounds) + + @staticmethod + def aabb_bottom_to_xy_plane_transform(bounds: Any) -> Any: + from . import utils as geometry_utils + + return geometry_utils._aabb_bottom_to_xy_plane_transform(bounds) + + @staticmethod + def scale_transform(scale: float) -> Any: + from . import utils as geometry_utils + + return geometry_utils._scale_transform(scale) + + @staticmethod + def compose_sam3d_multi_object_transform(**kwargs: Any) -> Any: + from . import utils as geometry_utils + + return geometry_utils._compose_sam3d_multi_object_transform(**kwargs) + + @staticmethod + def detect_table_fit_support_quad( + mesh: Any, + *, + target_aspect: float, + ) -> dict[str, Any]: + from . import utils as geometry_utils + + return geometry_utils._detect_table_fit_support_quad( + mesh, + target_aspect=target_aspect, + ) + + @staticmethod + def load_table_fit_scene_internal_z(path: Path, *, trimesh: Any, y_to_z: Any) -> Any: + from . import utils as geometry_utils + + return geometry_utils._load_table_fit_scene_internal_z( + path, + trimesh=trimesh, + y_to_z=y_to_z, + ) + + @staticmethod + def table_fit_scene_union_bounds(scenes: list[Any], *, trimesh: Any) -> Any: + from . import utils as geometry_utils + + return geometry_utils._table_fit_scene_union_bounds(scenes, trimesh=trimesh) + + @staticmethod + def table_fit_bounds_xy_manifest( + bounds: Any, + *, + unit_scale: float, + ) -> dict[str, Any]: + from . import utils as geometry_utils + + return geometry_utils._table_fit_bounds_xy_manifest( + bounds, + unit_scale=unit_scale, + ) + + @staticmethod + def table_fit_uniform_xy_scale_transform(**kwargs: Any) -> Any: + from . import utils as geometry_utils + + return geometry_utils._table_fit_uniform_xy_scale_transform(**kwargs) + + @staticmethod + def table_fit_safe_positive_ratio(numerator: float, denominator: float) -> float: + from . import utils as geometry_utils + + return geometry_utils._table_fit_safe_positive_ratio(numerator, denominator) @staticmethod def load_mesh(request: LoadMeshRequest) -> LoadMeshResult: @@ -228,17 +394,22 @@ def best_axis_bbox_scale_match( return best @staticmethod - def scene_to_mesh(scene: Any) -> Any: + def scene_to_mesh(scene: Any, *, trimesh: Any | None = None) -> Any: """Convert a trimesh Scene or mesh-like object to one mesh.""" - if isinstance(scene, trimesh.Trimesh): + trimesh_module = globals()["trimesh"] + if trimesh is not None: + trimesh_module = trimesh + if isinstance(scene, trimesh_module.Trimesh): return scene dumped = scene.dump(concatenate=True) - if isinstance(dumped, trimesh.Trimesh): + if isinstance(dumped, trimesh_module.Trimesh): return dumped - meshes = [item for item in dumped if isinstance(item, trimesh.Trimesh)] + meshes = [ + item for item in dumped if isinstance(item, trimesh_module.Trimesh) + ] if not meshes: raise ValueError("Scene contains no mesh geometry.") - return trimesh.util.concatenate(meshes) + return trimesh_module.util.concatenate(meshes) @staticmethod def detect_tabletop( diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/scene_geometry.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/utils.py similarity index 100% rename from embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/scene_geometry.py rename to embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/utils.py diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/prompts.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/prompts.py deleted file mode 100644 index 85b41388b..000000000 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/prompts.py +++ /dev/null @@ -1,106 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - -from __future__ import annotations - -import json -from pathlib import Path -from typing import Any - -from embodichain.gen_sim.prompt2scene.prompts import render_prompt -from embodichain.gen_sim.prompt2scene.utils.io import image_to_data_url - -__all__ = [ - "build_image_metric_scale_messages", - "build_up_down_flip_check_messages", -] - -UNIFIED_SCENE_GEN_PROMPT_NAME = "unified_scene_gen.yaml" - - -def build_image_metric_scale_messages( - *, - bbox_name_image_path: Path, - objects_json: list[dict[str, Any]], -) -> list[dict[str, Any]]: - return [ - { - "role": "system", - "content": render_prompt( - UNIFIED_SCENE_GEN_PROMPT_NAME, - prompt_key="image_metric_scale_system", - ), - }, - { - "role": "user", - "content": [ - { - "type": "text", - "text": render_prompt( - UNIFIED_SCENE_GEN_PROMPT_NAME, - { - "objects_json": json.dumps( - objects_json, - ensure_ascii=False, - indent=2, - ), - }, - prompt_key="image_metric_scale_user", - ), - }, - { - "type": "image_url", - "image_url": {"url": image_to_data_url(bbox_name_image_path)}, - }, - ], - }, - ] - - -def build_up_down_flip_check_messages( - *, - original_image_path: Path, - comparison_image_path: Path, -) -> list[dict[str, Any]]: - return [ - { - "role": "system", - "content": render_prompt( - UNIFIED_SCENE_GEN_PROMPT_NAME, - prompt_key="up_down_flip_check_system", - ), - }, - { - "role": "user", - "content": [ - { - "type": "text", - "text": render_prompt( - UNIFIED_SCENE_GEN_PROMPT_NAME, - prompt_key="up_down_flip_check_user", - ), - }, - { - "type": "image_url", - "image_url": {"url": image_to_data_url(original_image_path)}, - }, - { - "type": "image_url", - "image_url": {"url": image_to_data_url(comparison_image_path)}, - }, - ], - }, - ] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/schemas.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/schemas.py deleted file mode 100644 index b22fcebba..000000000 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/schemas.py +++ /dev/null @@ -1,71 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - -from __future__ import annotations - -from typing import Any - -__all__ = [ - "IMAGE_METRIC_SCALE_JSON_SCHEMA", - "UP_DOWN_FLIP_CHECK_JSON_SCHEMA", -] - -UP_DOWN_FLIP_CHECK_JSON_SCHEMA: dict[str, Any] = { - "title": "AlignedUpDownFlipCheckOutput", - "type": "object", - "additionalProperties": False, - "properties": { - "selected_number": {"type": "integer", "enum": [1, 2]}, - "confidence": {"type": "number", "minimum": 0.0, "maximum": 1.0}, - "reason": {"type": "string"}, - }, - "required": ["selected_number", "confidence", "reason"], -} - -IMAGE_METRIC_SCALE_JSON_SCHEMA: dict[str, Any] = { - "title": "ImageMetricScaleEstimate", - "type": "object", - "additionalProperties": False, - "properties": { - "object_scales": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": False, - "properties": { - "object_id": {"type": "string"}, - "bbox_dims_cm": { - "type": "array", - "minItems": 3, - "maxItems": 3, - "items": { - "type": "number", - "minimum": 1.0e-6, - }, - }, - "confidence": { - "type": "number", - "minimum": 0.0, - "maximum": 1.0, - }, - "reason": {"type": "string"}, - }, - "required": ["object_id", "bbox_dims_cm", "confidence", "reason"], - }, - }, - }, - "required": ["object_scales"], -} diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/layout_manager/__init__.py similarity index 83% rename from embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/__init__.py rename to embodichain/gen_sim/prompt2scene/agent_tools/managers/layout_manager/__init__.py index 0819a0d37..8d0684d85 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/__init__.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/layout_manager/__init__.py @@ -16,8 +16,8 @@ from __future__ import annotations -from embodichain.gen_sim.prompt2scene.agent_tools.managers.table_clutter_fit_manager.manager import ( - fit_table_to_clutter, +from embodichain.gen_sim.prompt2scene.agent_tools.managers.layout_manager.manager import ( + LayoutManager, ) -__all__ = ["fit_table_to_clutter"] +__all__ = ["LayoutManager"] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/layout_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/layout_manager/manager.py new file mode 100644 index 000000000..a36e37506 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/layout_manager/manager.py @@ -0,0 +1,76 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +from .utils import ( + _center_xy_aabb_layout, + _footprint_layout_diagnostics, + _layout_text_objects_grid, + _object_scenes_xy_aabb_manifest, + _optimize_text_layout_slp, + _settle_and_pack_object_footprints, + _xy_aabb_overlap, + _xy_union_area, + _xy_union_bounds, +) + + +class LayoutManager: + """Public API for layout planning and footprint analysis. + + Tools should compose these methods instead of importing private helpers from + ``layout_manager.utils`` directly. The utils module remains an internal + implementation detail for shared math and optimization routines. + """ + + @staticmethod + def center_xy_aabb_layout(**kwargs: Any) -> Any: + return _center_xy_aabb_layout(**kwargs) + + @staticmethod + def footprint_layout_diagnostics(**kwargs: Any) -> Any: + return _footprint_layout_diagnostics(**kwargs) + + @staticmethod + def layout_text_objects_grid(**kwargs: Any) -> Any: + return _layout_text_objects_grid(**kwargs) + + @staticmethod + def object_scenes_xy_aabb_manifest(**kwargs: Any) -> Any: + return _object_scenes_xy_aabb_manifest(**kwargs) + + @staticmethod + def optimize_text_layout_slp(**kwargs: Any) -> Any: + return _optimize_text_layout_slp(**kwargs) + + @staticmethod + def settle_and_pack_object_footprints(**kwargs: Any) -> Any: + return _settle_and_pack_object_footprints(**kwargs) + + @staticmethod + def xy_aabb_overlap(**kwargs: Any) -> Any: + return _xy_aabb_overlap(**kwargs) + + @staticmethod + def xy_union_area(bounds: Any) -> float: + return _xy_union_area(bounds) + + @staticmethod + def xy_union_bounds(**kwargs: Any) -> Any: + return _xy_union_bounds(**kwargs) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/layout_manager/schemas.py similarity index 67% rename from embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/__init__.py rename to embodichain/gen_sim/prompt2scene/agent_tools/managers/layout_manager/schemas.py index 2ad8f11a5..015c41510 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/__init__.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/layout_manager/schemas.py @@ -16,14 +16,4 @@ from __future__ import annotations -from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager.alignment import ( - _export_support_aligned_layout_glbs, -) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager.manifests import ( - _write_multi_object_layout_manifests, -) - -__all__ = [ - "_export_support_aligned_layout_glbs", - "_write_multi_object_layout_manifests", -] +__all__: list[str] = [] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/layout_manager/utils.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/layout_manager/utils.py new file mode 100644 index 000000000..a4b2dde39 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/layout_manager/utils.py @@ -0,0 +1,1350 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. + + +from __future__ import annotations + +import tempfile +import traceback +from pathlib import Path +from typing import Any + +import numpy as np + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager import ( + SimulationManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager.schemas import ( + GravityDropRequest, +) +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager import ( + GeometryManager, +) +from embodichain.gen_sim.prompt2scene.utils.io import ( + relative_path, +) + +__all__ = [ + "_center_xy_aabb_layout", + "_object_scenes_xy_aabb_manifest", + "_settle_and_pack_object_footprints", + "_xy_aabb_overlap", + "_xy_union_area", + "_xy_union_bounds", +] + +def _object_scenes_xy_aabb_manifest( + *, + object_scenes: list[tuple[str, Any]], + trimesh: Any, + unit_scale: float, + unit: str, +) -> dict[str, Any]: + if not object_scenes: + return { + "status": "empty", + "unit": unit, + "object_count": 0, + } + bounds = [ + np.asarray( + GeometryManager.scene_to_mesh(scene, trimesh=trimesh).bounds, + dtype=np.float64, + ) + for _, scene in object_scenes + ] + union_bounds = np.vstack( + [ + np.vstack([item[0] for item in bounds]).min(axis=0), + np.vstack([item[1] for item in bounds]).max(axis=0), + ] + ) + min_xy = union_bounds[0, :2] * unit_scale + max_xy = union_bounds[1, :2] * unit_scale + size_xy = max_xy - min_xy + center_xy = 0.5 * (min_xy + max_xy) + return { + "status": "ok", + "unit": unit, + "object_count": len(object_scenes), + "min_xy": min_xy.tolist(), + "max_xy": max_xy.tolist(), + "center_xy": center_xy.tolist(), + "size_xy": size_xy.tolist(), + "area": float(size_xy[0] * size_xy[1]), + } + + + +def _settle_and_pack_object_footprints( + *, + object_scenes: list[tuple[str, Any]], + output_dir: Path, + output_root: Path, + trimesh: Any, +) -> dict[str, Any]: + sim = SimulationManager(headless=True, sim_device="cpu") + footprint_items: list[dict[str, Any]] = [] + settled_entries: list[dict[str, Any]] = [] + output_axis_transform = GeometryManager.z_up_to_glb_y_up_transform() + output_to_internal_transform = np.linalg.inv(output_axis_transform) + + with tempfile.TemporaryDirectory(prefix="p2s_footprint_drop_") as tmp_dir: + tmp_path = Path(tmp_dir) + for object_id, scene in object_scenes: + mesh = GeometryManager.scene_to_mesh(scene, trimesh=trimesh) + mesh_bounds = np.asarray(mesh.bounds, dtype=np.float64) + mesh_z_height = max(float(mesh_bounds[1][2] - mesh_bounds[0][2]), 0.0) + bottom_to_xy_plane_transform = GeometryManager.aabb_bottom_to_xy_plane_transform( + mesh_bounds + ) + normalized_scene = GeometryManager.copy_scene_with_transform( + scene, + bottom_to_xy_plane_transform, + ) + normalized_output_scene = GeometryManager.copy_scene_with_transform( + normalized_scene, + output_axis_transform, + ) + pre_gravity_path = tmp_path / f"{object_id}_pre_gravity.glb" + normalized_output_scene.export(pre_gravity_path) + gravity_initial_height = mesh_z_height * 0.1 + + gravity_status = "ok" + gravity_transform = np.eye(4, dtype=np.float64) + gravity_reason = "" + try: + gravity_result = sim.run_gravity_simulation( + GravityDropRequest( + glb_path=pre_gravity_path, + max_convex_hull_num=32, + initial_height=gravity_initial_height, + ) + ) + gravity_transform = GeometryManager.matrix_from_json( + gravity_result.final_pose, + name=f"{object_id}.gravity_final_pose", + ) + except Exception: + gravity_status = "failed" + gravity_reason = traceback.format_exc() + + settled_origin_scene = GeometryManager.copy_scene_with_transform( + normalized_scene, + gravity_transform, + ) + settled_mesh = GeometryManager.scene_to_mesh( + settled_origin_scene, + trimesh=trimesh, + ) + settled_bounds = np.asarray(settled_mesh.bounds, dtype=np.float64) + settled_xy_center = GeometryManager.xy_aabb_center(settled_bounds) + settled_xy_size = GeometryManager.xy_aabb_size(settled_bounds) + settled_entries.append( + { + "id": object_id, + "scene": scene, + "bottom_to_xy_plane_transform": bottom_to_xy_plane_transform, + "mesh_z_height": mesh_z_height, + "gravity_initial_height": gravity_initial_height, + "gravity_transform": gravity_transform, + "settled_bounds": settled_bounds, + "settled_xy_center": settled_xy_center, + "settled_xy_size": settled_xy_size, + "gravity_status": gravity_status, + "gravity_reason": gravity_reason, + } + ) + + layout_result = _optimize_xy_aabb_footprint_layout( + object_ids=[str(entry["id"]) for entry in settled_entries], + xy_sizes={ + str(entry["id"]): np.asarray(entry["settled_xy_size"], dtype=np.float64) + for entry in settled_entries + }, + current_centers={ + str(entry["id"]): GeometryManager.xy_aabb_center( + GeometryManager.scene_to_mesh( + entry["scene"], + trimesh=trimesh, + ).bounds + ) + for entry in settled_entries + }, + ) + target_centers = layout_result["centers"] + + packed_object_scenes: list[tuple[str, Any]] = [] + object_layout_transforms: dict[str, np.ndarray] = {} + for entry in settled_entries: + object_id = str(entry["id"]) + settled_bounds = np.asarray(entry["settled_bounds"], dtype=np.float64) + target_xy = target_centers[object_id] + placement_transform = np.eye(4, dtype=np.float64) + placement_transform[:3, 3] = [ + float(target_xy[0] - entry["settled_xy_center"][0]), + float(target_xy[1] - entry["settled_xy_center"][1]), + -float(settled_bounds[0][2]), + ] + object_transform = ( + placement_transform + @ entry["gravity_transform"] + @ entry["bottom_to_xy_plane_transform"] + ) + packed_scene = GeometryManager.copy_scene_with_transform( + entry["scene"], + object_transform, + ) + packed_object_scenes.append((object_id, packed_scene)) + object_layout_transforms[object_id] = object_transform + + packed_bounds = np.asarray( + GeometryManager.scene_to_mesh(packed_scene, trimesh=trimesh).bounds, + dtype=np.float64, + ) + footprint_items.append( + { + "id": object_id, + "gravity_status": entry["gravity_status"], + "gravity_reason": entry["gravity_reason"], + "bottom_to_xy_plane_transform": entry[ + "bottom_to_xy_plane_transform" + ].tolist(), + "mesh_z_height": entry["mesh_z_height"], + "gravity_initial_height": entry["gravity_initial_height"], + "gravity_transform": entry["gravity_transform"].tolist(), + "placement_transform": placement_transform.tolist(), + "object_layout_transform": object_transform.tolist(), + "settled_xy_size": entry["settled_xy_size"].tolist(), + "target_xy_center": target_xy.tolist(), + "packed_bounds": packed_bounds.tolist(), + } + ) + + manifest = { + "status": "ok", + "method": "per_object_gravity_then_geometry_knn_2d_aabb_relaxation", + "output_dir": relative_path(str(output_dir), output_root), + "internal_up_axis": [0.0, 0.0, 1.0], + "gravity_glb_up_axis": [0.0, 1.0, 0.0], + "internal_to_gravity_glb_transform": output_axis_transform.tolist(), + "gravity_glb_to_internal_transform": output_to_internal_transform.tolist(), + "layout_optimization": layout_result["metadata"], + "items": footprint_items, + } + return { + "object_scenes": packed_object_scenes, + "object_layout_transforms": object_layout_transforms, + "manifest": manifest, + } + + + +def _optimize_xy_aabb_footprint_layout( + *, + object_ids: list[str], + xy_sizes: dict[str, np.ndarray], + current_centers: dict[str, np.ndarray], + padding_ratio: float = 0.08, +) -> dict[str, Any]: + if not object_ids: + return { + "centers": {}, + "metadata": { + "method": "geometry_knn_2d_aabb_relaxation", + "iterations": 0, + "confidence_score": 1.0, + }, + } + + max_extent = max( + float(max(xy_sizes[object_id][0], xy_sizes[object_id][1])) + for object_id in object_ids + ) + padding = max(max_extent * padding_ratio, 1e-3) + max_iterations = 300 + overlap_strength = 1.0 + neighbor_strength = 0.04 + compactness_strength = 0.01 + target_expansion_ratio = 1.2 + knn_k = min(3, max(len(object_ids) - 1, 0)) + centers = { + object_id: np.asarray( + current_centers.get(object_id, np.zeros(2, dtype=np.float64)), + dtype=np.float64, + ).copy() + for object_id in object_ids + } + centers = _center_xy_aabb_layout( + centers=centers, + xy_sizes=xy_sizes, + ) + initial_centers = { + object_id: center.copy() + for object_id, center in centers.items() + } + initial_union_bounds = _xy_union_bounds( + centers=initial_centers, + xy_sizes=xy_sizes, + ) + neighbor_edges = _knn_neighbor_edges( + centers=initial_centers, + k=knn_k, + ) + + iterations = 0 + for iteration in range(max_iterations): + iterations = iteration + 1 + max_delta = 0.0 + + for i, object_id in enumerate(object_ids): + for other_id in object_ids[i + 1 :]: + overlap = _xy_aabb_overlap( + center_a=centers[object_id], + size_a=xy_sizes[object_id], + center_b=centers[other_id], + size_b=xy_sizes[other_id], + padding=padding, + ) + if overlap is None: + continue + overlap_x, overlap_y = overlap + if overlap_x <= overlap_y: + axis = 0 + sign = ( + -1.0 + if centers[object_id][0] <= centers[other_id][0] + else 1.0 + ) + amount = overlap_x + else: + axis = 1 + sign = ( + -1.0 + if centers[object_id][1] <= centers[other_id][1] + else 1.0 + ) + amount = overlap_y + shift = 0.5 * (amount + 1e-6) * overlap_strength + centers[object_id][axis] += sign * shift + centers[other_id][axis] -= sign * shift + max_delta = max(max_delta, shift) + + for edge in neighbor_edges: + object_id = edge["object"] + neighbor_id = edge["neighbor"] + initial_delta = np.asarray(edge["initial_delta"], dtype=np.float64) + error = (centers[object_id] - centers[neighbor_id]) - initial_delta + correction = 0.5 * neighbor_strength * error + centers[object_id] -= correction + centers[neighbor_id] += correction + max_delta = max(max_delta, float(np.linalg.norm(correction))) + + max_delta = max( + max_delta, + _apply_compactness_pull( + centers=centers, + xy_sizes=xy_sizes, + initial_union_bounds=initial_union_bounds, + target_expansion_ratio=target_expansion_ratio, + strength=compactness_strength, + ), + ) + + centers = _center_xy_aabb_layout( + centers=centers, + xy_sizes=xy_sizes, + ) + if iteration >= 20 and max_delta < 1e-5: + break + + diagnostics = _footprint_layout_diagnostics( + object_ids=object_ids, + centers=centers, + initial_centers=initial_centers, + xy_sizes=xy_sizes, + padding=padding, + initial_union_bounds=initial_union_bounds, + ) + metadata = { + "method": "geometry_knn_2d_aabb_relaxation", + "relation_usage": "disabled", + "iterations": iterations, + "padding": padding, + "padding_ratio": padding_ratio, + "max_iterations": max_iterations, + "overlap_strength": overlap_strength, + "neighbor_strength": neighbor_strength, + "compactness_strength": compactness_strength, + "target_expansion_ratio": target_expansion_ratio, + "knn_k": knn_k, + "neighbor_edges": neighbor_edges, + "final_centers": { + object_id: centers[object_id].tolist() + for object_id in object_ids + }, + **diagnostics, + } + return {"centers": centers, "metadata": metadata} + + + +def _knn_neighbor_edges( + *, + centers: dict[str, np.ndarray], + k: int, +) -> list[dict[str, Any]]: + if k <= 0 or len(centers) < 2: + return [] + object_ids = sorted(centers) + edges: list[dict[str, Any]] = [] + seen: set[tuple[str, str]] = set() + for object_id in object_ids: + distances = [] + for other_id in object_ids: + if other_id == object_id: + continue + distance = float(np.linalg.norm(centers[object_id] - centers[other_id])) + distances.append((distance, other_id)) + for _, neighbor_id in sorted(distances)[:k]: + edge_key = tuple(sorted((object_id, neighbor_id))) + if edge_key in seen: + continue + seen.add(edge_key) + edges.append( + { + "object": object_id, + "neighbor": neighbor_id, + "initial_delta": ( + centers[object_id] - centers[neighbor_id] + ).tolist(), + } + ) + return edges + + + +def _apply_compactness_pull( + *, + centers: dict[str, np.ndarray], + xy_sizes: dict[str, np.ndarray], + initial_union_bounds: np.ndarray, + target_expansion_ratio: float, + strength: float, +) -> float: + current_bounds = _xy_union_bounds(centers=centers, xy_sizes=xy_sizes) + expansion_ratio = _xy_union_area(current_bounds) / max( + _xy_union_area(initial_union_bounds), + 1.0e-12, + ) + if expansion_ratio <= target_expansion_ratio: + return 0.0 + excess = min(expansion_ratio / target_expansion_ratio - 1.0, 1.0) + union_center = 0.5 * (current_bounds[0] + current_bounds[1]) + factor = strength * excess + max_delta = 0.0 + for object_id, center in centers.items(): + delta = factor * (union_center - center) + centers[object_id] = center + delta + max_delta = max(max_delta, float(np.linalg.norm(delta))) + return max_delta + + + +def _footprint_layout_diagnostics( + *, + object_ids: list[str], + centers: dict[str, np.ndarray], + initial_centers: dict[str, np.ndarray], + xy_sizes: dict[str, np.ndarray], + padding: float, + initial_union_bounds: np.ndarray, +) -> dict[str, Any]: + remaining_overlaps = _remaining_xy_overlaps( + object_ids=object_ids, + centers=centers, + xy_sizes=xy_sizes, + padding=padding, + ) + displacements = [ + float(np.linalg.norm(centers[object_id] - initial_centers[object_id])) + for object_id in object_ids + ] + current_union_bounds = _xy_union_bounds(centers=centers, xy_sizes=xy_sizes) + expansion_ratio = _xy_union_area(current_union_bounds) / max( + _xy_union_area(initial_union_bounds), + 1.0e-12, + ) + average_displacement = float(np.mean(displacements)) if displacements else 0.0 + max_displacement = float(np.max(displacements)) if displacements else 0.0 + confidence_score = _footprint_confidence_score( + remaining_overlap_count=len(remaining_overlaps), + average_displacement=average_displacement, + max_extent=max( + float(max(xy_sizes[object_id][0], xy_sizes[object_id][1])) + for object_id in object_ids + ) + if object_ids + else 1.0, + expansion_ratio=expansion_ratio, + ) + return { + "remaining_overlaps": remaining_overlaps, + "average_displacement": average_displacement, + "max_displacement": max_displacement, + "union_aabb_expansion_ratio": expansion_ratio, + "confidence_score": confidence_score, + } + + + +def _remaining_xy_overlaps( + *, + object_ids: list[str], + centers: dict[str, np.ndarray], + xy_sizes: dict[str, np.ndarray], + padding: float, +) -> list[dict[str, Any]]: + overlaps: list[dict[str, Any]] = [] + for index, object_id in enumerate(object_ids): + for other_id in object_ids[index + 1 :]: + overlap = _xy_aabb_overlap( + center_a=centers[object_id], + size_a=xy_sizes[object_id], + center_b=centers[other_id], + size_b=xy_sizes[other_id], + padding=padding, + ) + if overlap is None: + continue + overlaps.append( + { + "object": object_id, + "other": other_id, + "overlap_x": overlap[0], + "overlap_y": overlap[1], + } + ) + return overlaps + + + +def _footprint_confidence_score( + *, + remaining_overlap_count: int, + average_displacement: float, + max_extent: float, + expansion_ratio: float, +) -> float: + displacement_scale = max(max_extent, 1.0e-6) + overlap_penalty = min(0.35 * remaining_overlap_count, 0.7) + displacement_penalty = min(0.1 * average_displacement / displacement_scale, 0.2) + expansion_penalty = min(max(expansion_ratio - 1.2, 0.0) * 0.25, 0.2) + return float( + np.clip( + 1.0 + - overlap_penalty + - displacement_penalty + - expansion_penalty, + 0.0, + 1.0, + ) + ) + + + +def _center_xy_aabb_layout( + *, + centers: dict[str, np.ndarray], + xy_sizes: dict[str, np.ndarray], +) -> dict[str, np.ndarray]: + if not centers: + return centers + bounds_min = [] + bounds_max = [] + for object_id, center in centers.items(): + half_size = 0.5 * np.asarray(xy_sizes[object_id], dtype=np.float64) + bounds_min.append(center - half_size) + bounds_max.append(center + half_size) + clutter_center = 0.5 * ( + np.vstack(bounds_min).min(axis=0) + + np.vstack(bounds_max).max(axis=0) + ) + return { + object_id: np.asarray(center, dtype=np.float64) - clutter_center + for object_id, center in centers.items() + } + + + +def _xy_union_bounds( + *, + centers: dict[str, np.ndarray], + xy_sizes: dict[str, np.ndarray], +) -> np.ndarray: + if not centers: + return np.zeros((2, 2), dtype=np.float64) + bounds_min = [] + bounds_max = [] + for object_id, center in centers.items(): + half_size = 0.5 * np.asarray(xy_sizes[object_id], dtype=np.float64) + bounds_min.append(np.asarray(center, dtype=np.float64) - half_size) + bounds_max.append(np.asarray(center, dtype=np.float64) + half_size) + return np.vstack( + [ + np.vstack(bounds_min).min(axis=0), + np.vstack(bounds_max).max(axis=0), + ] + ) + + + +def _xy_union_area(bounds: np.ndarray) -> float: + bounds = np.asarray(bounds, dtype=np.float64) + size = np.maximum(bounds[1] - bounds[0], 1.0e-9) + return float(size[0] * size[1]) + + + +def _xy_aabb_overlap( + *, + center_a: np.ndarray, + size_a: np.ndarray, + center_b: np.ndarray, + size_b: np.ndarray, + padding: float, +) -> tuple[float, float] | None: + half_a = 0.5 * np.asarray(size_a, dtype=np.float64) + half_b = 0.5 * np.asarray(size_b, dtype=np.float64) + delta = np.abs( + np.asarray(center_b, dtype=np.float64) + - np.asarray(center_a, dtype=np.float64) + ) + overlap = half_a + half_b + padding - delta + if float(overlap[0]) <= 0.0 or float(overlap[1]) <= 0.0: + return None + return float(overlap[0]), float(overlap[1]) +# http://www.apache.org/licenses/LICENSE-2.0 +# distributed under the License is distributed on an "AS IS" BASIS, + + + +from typing import Any + +import numpy as np + +__all__: list[str] = [] + +def _transitive_closure( + nodes: list[str], + edges: list[tuple[str, str]], +) -> list[tuple[str, str]]: + """Floyd–Warshall transitive closure over a small set of nodes.""" + if not nodes or not edges: + return list(edges) + idx = {n: i for i, n in enumerate(nodes)} + n = len(nodes) + adj = [[False] * n for _ in range(n)] + for src, dst in edges: + if src in idx and dst in idx: + adj[idx[src]][idx[dst]] = True + for k in range(n): + for i in range(n): + if adj[i][k]: + row_k = adj[k] + row_i = adj[i] + for j in range(n): + if row_k[j]: + row_i[j] = True + closed: list[tuple[str, str]] = [] + for i in range(n): + for j in range(n): + if adj[i][j]: + closed.append((nodes[i], nodes[j])) + return closed + + + +def _longest_path_ranks( + nodes: list[str], + edges: list[tuple[str, str]], +) -> dict[str, int]: + """Assign integer ranks satisfying ``(A,B)`` → rank[A] < rank[B]. + + Uses topological sort + longest-path DP. Returns a rank dict for every + node in *nodes* (default 0 for isolated nodes). + """ + ranks: dict[str, int] = {n: 0 for n in nodes} + if not edges: + return ranks + # Build adjacency and in-degree + adj: dict[str, list[str]] = {n: [] for n in nodes} + in_deg: dict[str, int] = {n: 0 for n in nodes} + present = set(nodes) + for src, dst in edges: + if src not in present or dst not in present: + continue + adj[src].append(dst) + in_deg[dst] += 1 + # Kahn topological sort + queue = [n for n in nodes if in_deg[n] == 0] + order: list[str] = [] + while queue: + u = queue.pop(0) + order.append(u) + for v in adj[u]: + in_deg[v] -= 1 + if in_deg[v] == 0: + queue.append(v) + # Longest path + for u in order: + for v in adj[u]: + if ranks[v] < ranks[u] + 1: + ranks[v] = ranks[u] + 1 + # Remaining nodes (cycles / isolated) keep rank 0 + return ranks + + + +def _layout_text_objects_grid( + *, + object_ids: list[str], + xy_sizes: dict[str, np.ndarray], + spatial_relations: list[dict[str, Any]], + table_constraints: list[dict[str, Any]] | None = None, + grid_spacing: float = 0.02, + padding_ratio: float = 0.08, +) -> dict[str, Any]: + """Lay out text-scene objects — transitive closure + longest-path ranks. + + 1. Transitive closure of left_of / front_of. + 2. Pick centre: explicit 9‑grid ʻcenterʼ, else highest-degree node. + 3. Longest-path rank assignment (left_of→X, front_of→Y). + 4. Shift 9‑grid anchors to their grid positions. + 5. Free objects auto‑wrap below. + 6. Convert ranks→XY using per‑column/row max sizes + gaps. + 7. SA point optimisation + mesh AABB collision cleanup. + """ + if not object_ids: + return { + "centers": {}, + "initial_centers": {}, + "metadata": { + "method": "transitive_closure_longest_path_with_9grid", + "iterations": 0, + }, + } + + # Parse spatial relations. + left_of_edges: list[tuple[str, str]] = [] + front_of_edges: list[tuple[str, str]] = [] + seen: set[tuple[str, str, str]] = set() + for rel in spatial_relations: + subject = str(rel.get("subject") or "") + obj = str(rel.get("object") or "") + relation = str(rel.get("relation") or "") + if not subject or not obj or subject == obj: + continue + key = (subject, relation, obj) + if key in seen: + continue + seen.add(key) + if relation == "left_of": + left_of_edges.append((subject, obj)) + elif relation == "front_of": + front_of_edges.append((subject, obj)) + + # Compute transitive closures. + left_of_closed = _transitive_closure(object_ids, left_of_edges) + front_of_closed = _transitive_closure(object_ids, front_of_edges) + + # Parse nine-grid constraints. + # −Y = front, so front row = 0, back row = 2 + _GRID_TO_RC: dict[str, tuple[int, int]] = { + "left_front": (0, 0), "center_front": (1, 0), "right_front": (2, 0), + "left_center": (0, 1), "center": (1, 1), "right_center": (2, 1), + "left_back": (0, 2), "center_back": (1, 2), "right_back": (2, 2), + "front": (1, 0), "back": (1, 2), + "left": (0, 1), "right": (2, 1), + } + grid_targets: dict[str, tuple[int, int]] = {} + for tc in (table_constraints or []): + asset = str(tc.get("asset") or "") + grid_name = str(tc.get("grid") or "").strip() + if asset in object_ids and grid_name in _GRID_TO_RC: + grid_targets[asset] = _GRID_TO_RC[grid_name] + + # Select a center object when none is explicit. + auto_center_oid: str | None = None + has_explicit_center = any( + tc.get("grid") == "center" for tc in (table_constraints or []) + ) + if not has_explicit_center: + # Degree = appearances in left_of + front_of (subject or object) + degree: dict[str, int] = {oid: 0 for oid in object_ids} + for src, dst in left_of_closed + front_of_closed: + if src in degree: + degree[src] += 1 + if dst in degree: + degree[dst] += 1 + max_deg = max(degree.values()) if degree else 0 + if max_deg > 0: + candidates = [oid for oid, d in degree.items() if d == max_deg] + # Tie-breaker: largest AABB area + centre_oid = max( + candidates, + key=lambda oid: float(xy_sizes[oid][0]) * float(xy_sizes[oid][1]), + ) + grid_targets[centre_oid] = (1, 1) # 9‑grid centre + auto_center_oid = centre_oid + + # Derive ranks from the transitive closures. + x_rank = _longest_path_ranks(object_ids, left_of_closed) + # −Y = front: A front_of B → A.y < B.y → row[A] < row[B]. + # _longest_path_ranks gives rank[src] < rank[dst]; edges are + # already (A,B) for "A front_of B", so NO reversal needed. + y_rank = _longest_path_ranks(object_ids, front_of_closed) + + # Apply nine-grid shifts. + # Pin 9‑grid objects to their target ranks; shift all connected + # objects (both upstream and downstream) to preserve topology. + if grid_targets: + # Build undirected connected-components via relation edges + all_edges = left_of_closed + front_of_closed + neighbours: dict[str, set[str]] = {oid: set() for oid in object_ids} + for src, dst in all_edges: + if src in neighbours and dst in neighbours: + neighbours[src].add(dst) + neighbours[dst].add(src) + for oid in grid_targets: + neighbours.setdefault(oid, set()) + + # For each 9‑grid object, BFS the component and shift uniformly + shifted: set[str] = set() + for oid, (target_col, target_row) in grid_targets.items(): + if oid in shifted: + continue + dx = target_col - x_rank.get(oid, 0) + dy = target_row - y_rank.get(oid, 0) + + # BFS to collect the full connected component + component: set[str] = {oid} + queue = [oid] + while queue: + u = queue.pop(0) + for v in neighbours.get(u, set()): + if v not in component: + component.add(v) + queue.append(v) + + for oid2 in component: + if oid2 not in grid_targets: # only shift non‑anchored objects + x_rank[oid2] = x_rank.get(oid2, 0) + dx + y_rank[oid2] = y_rank.get(oid2, 0) + dy + shifted.update(component) + + # Propagate row and column alignment. + # left_of A B → same row (y_rank[A] = y_rank[B]) + # front_of A B → same col (x_rank[A] = x_rank[B]) + # Priority (higher wins): 9‑grid > higher degree > larger area. + _prio = { + oid: ( + oid in grid_targets, + sum(1 for e in left_of_closed + front_of_closed if oid in e), + float(xy_sizes[oid][0]) * float(xy_sizes[oid][1]), + ) + for oid in object_ids + } + for src, dst in left_of_closed: + if _prio[src] >= _prio[dst]: + y_rank[dst] = y_rank.get(src, 0) + else: + y_rank[src] = y_rank.get(dst, 0) + for src, dst in front_of_closed: + if _prio[src] >= _prio[dst]: + x_rank[dst] = x_rank.get(src, 0) + else: + x_rank[src] = x_rank.get(dst, 0) + + # Normalise to >= 0 + min_x = min(x_rank.values()) if x_rank else 0 + min_y = min(y_rank.values()) if y_rank else 0 + for oid in object_ids: + x_rank[oid] = x_rank.get(oid, 0) - min_x + y_rank[oid] = y_rank.get(oid, 0) - min_y + + # Resolve cell collisions: spread objects sharing the same (col, row) + cell_occupants: dict[tuple[int, int], list[str]] = {} + for oid in object_ids: + cell = (x_rank[oid], y_rank[oid]) + cell_occupants.setdefault(cell, []).append(oid) + for (col, row), occupants in cell_occupants.items(): + if len(occupants) > 1: + for offset, oid in enumerate(occupants[1:], start=1): + x_rank[oid] = col + offset + + # Place unconstrained objects in wrapped rows. + constrained = set() + for src, dst in left_of_closed + front_of_closed: + constrained.update([src, dst]) + constrained.update(grid_targets) + free_objects = [oid for oid in object_ids if oid not in constrained] + + if free_objects: + free_row = max(y_rank.values()) + 1 if y_rank else 0 + # Max row width ≈ existing union width × 1.5 (at least 3 cols) + col_keys = list(x_rank.values()) + existing_cols = max(col_keys) - min(col_keys) + 1 if col_keys else 1 + max_cols_per_row = max(existing_cols, 3) + free_sorted = sorted( + free_objects, + key=lambda oid: float(xy_sizes[oid][0]), + reverse=True, + ) + col = 0 + row_offset = 0 + for oid in free_sorted: + x_rank[oid] = col + y_rank[oid] = free_row + row_offset + col += 1 + if col >= max_cols_per_row: + col = 0 + row_offset += 1 + + # Convert ranks to XY positions. + col_widths: dict[int, float] = {} + row_heights: dict[int, float] = {} + for oid in object_ids: + c = x_rank[oid] + r = y_rank[oid] + col_widths[c] = max(col_widths.get(c, 0.0), float(xy_sizes[oid][0])) + row_heights[r] = max(row_heights.get(r, 0.0), float(xy_sizes[oid][1])) + + x_cumsum: dict[int, float] = {} + cumulative = 0.0 + for c in sorted(col_widths): + x_cumsum[c] = cumulative + cumulative += col_widths[c] + grid_spacing + + y_cumsum: dict[int, float] = {} + cumulative = 0.0 + for r in sorted(row_heights): + y_cumsum[r] = cumulative + cumulative += row_heights[r] + grid_spacing + + centers: dict[str, np.ndarray] = {} + for oid in object_ids: + c = x_rank[oid] + r = y_rank[oid] + cx = x_cumsum[c] + 0.5 * float(xy_sizes[oid][0]) + cy = y_cumsum[r] + 0.5 * float(xy_sizes[oid][1]) + centers[oid] = np.array([cx, cy], dtype=np.float64) + + centers = _center_xy_aabb_layout(centers=centers, xy_sizes=xy_sizes) + + initial_centers = {oid: c.copy() for oid, c in centers.items()} + + # Snap initial grid positions as 9‑grid spring targets + grid_spring_targets: dict[str, np.ndarray] = { + oid: initial_centers[oid].copy() + for oid in grid_targets + if oid in initial_centers + } + + # Optimize positions and remove mesh AABB collisions. + optimized = _optimize_text_layout_slp( + object_ids=object_ids, + xy_sizes=xy_sizes, + initial_centers=initial_centers, + left_of_edges=left_of_closed, + front_of_edges=front_of_closed, + grid_spring_targets=grid_spring_targets, + padding_ratio=padding_ratio, + ) + centers = optimized["centers"] + optimization_metadata = optimized["metadata"] + + # Collect layout metadata. + metadata = { + "method": "transitive_closure_longest_path_with_9grid_and_sa", + "grid_spacing": grid_spacing, + "auto_center_oid": auto_center_oid, + "has_explicit_center": has_explicit_center, + "table_constraint_count": len(grid_targets), + "left_of_count": len(left_of_edges), + "left_of_closed_count": len(left_of_closed), + "front_of_count": len(front_of_edges), + "front_of_closed_count": len(front_of_closed), + "free_object_count": len(free_objects), + "x_ranks": {oid: x_rank.get(oid, 0) for oid in object_ids}, + "y_ranks": {oid: y_rank.get(oid, 0) for oid in object_ids}, + "optimization": optimization_metadata, + } + return { + "centers": centers, + "initial_centers": initial_centers, + "metadata": metadata, + } +# http://www.apache.org/licenses/LICENSE-2.0 + +def _optimize_text_layout_slp( + *, + object_ids: list[str], + xy_sizes: dict[str, np.ndarray], + initial_centers: dict[str, np.ndarray], + left_of_edges: list[tuple[str, str]], + front_of_edges: list[tuple[str, str]], + grid_spring_targets: dict[str, np.ndarray], + padding_ratio: float, +) -> dict[str, Any]: + """Optimize 2D centres with scipy SLSQP, then remove mesh AABB overlap. + + Mirroring the original example_optimization/SA pipeline: + - left_of / front_of → linear inequality constraints + - bounding box → variable bounds (2× initial union) + - seed / overlap / grid → soft penalties in the objective + - post‑solve collision cleanup on actual footprint AABBs + """ + if not object_ids: + return { + "centers": {}, + "metadata": { + "method": "text_slsqp_then_mesh_aabb_collision_removal", + "slsqp_iterations": 0, + "collision_iterations": 0, + }, + } + + max_extent = max( + float(max(xy_sizes[oid][0], xy_sizes[oid][1])) for oid in object_ids + ) + padding = max(max_extent * padding_ratio, 1e-3) + + initial_centers = { + oid: np.asarray(initial_centers[oid], dtype=np.float64).copy() + for oid in object_ids + } + initial_union_bounds = _xy_union_bounds( + centers=initial_centers, + xy_sizes=xy_sizes, + ) + + index_by_id = {oid: i for i, oid in enumerate(object_ids)} + x0 = _pack_centers(object_ids, initial_centers) + + # Build linear inequality constraints for left_of and front_of. + constraints: list[dict[str, Any]] = [] + _build_relation_constraints( + constraints=constraints, + object_ids=object_ids, + index_by_id=index_by_id, + xy_sizes=xy_sizes, + left_of_edges=left_of_edges, + front_of_edges=front_of_edges, + padding=padding, + ) + + # Bound variables to twice the initial union size. + init_size = initial_union_bounds[1] - initial_union_bounds[0] + margin = init_size * 0.5 # 50 % each side → 2× total + bounds = [] + for oid in object_ids: + bounds.append( + ( + float(initial_union_bounds[0, 0] - margin[0]), + float(initial_union_bounds[1, 0] + margin[0]), + ) + ) # x + bounds.append( + ( + float(initial_union_bounds[0, 1] - margin[1]), + float(initial_union_bounds[1, 1] + margin[1]), + ) + ) # y + + # Define the optimization objective. + def _objective(xvec: np.ndarray) -> float: + centers = _unpack_centers(object_ids, xvec) + loss = 0.0 + + # seed: stay close to initial positions + for oid in object_ids: + delta = centers[oid] - initial_centers[oid] + loss += _WEIGHTS["seed"] * float(np.dot(delta, delta)) + + # overlap: AABB overlap area penalty + for i, oid in enumerate(object_ids): + for other_id in object_ids[i + 1 :]: + ov = _xy_aabb_overlap( + center_a=centers[oid], + size_a=xy_sizes[oid], + center_b=centers[other_id], + size_b=xy_sizes[other_id], + padding=padding, + ) + if ov is not None: + loss += _WEIGHTS["overlap"] * float(ov[0] * ov[1]) + + # grid: spring toward 9‑grid targets + for oid, target in grid_spring_targets.items(): + if oid not in centers: + continue + delta = centers[oid] - target + loss += _WEIGHTS["grid"] * float(np.dot(delta, delta)) + + return float(loss) + + # Solve the constrained optimization problem. + slsqp_result: dict[str, Any] = {"success": False, "nit": 0, "message": ""} + try: + result = minimize( + _objective, + x0, + method="SLSQP", + bounds=bounds, + constraints=constraints, + options=_SLSQP_OPTIONS, + ) + slsqp_result = { + "success": bool(result.success), + "nit": int(getattr(result, "nit", 0)), + "message": str(result.message), + "fun": float(result.fun) if result.fun is not None else None, + } + if result.success: + x_opt = result.x + else: + # SLSQP failed — fall back to seed positions + x_opt = x0.copy() + except Exception: + x_opt = x0.copy() + slsqp_result["message"] = "SLSQP raised an exception; using seed positions." + + centers = _unpack_centers(object_ids, x_opt) + centers = _center_xy_aabb_layout(centers=centers, xy_sizes=xy_sizes) + + # Remove residual collisions. + centers, collision_metadata = _remove_mesh_aabb_collisions( + object_ids=object_ids, + xy_sizes=xy_sizes, + centers=centers, + initial_centers=initial_centers, + left_of_edges=left_of_edges, + front_of_edges=front_of_edges, + padding=padding, + ) + centers = _center_xy_aabb_layout(centers=centers, xy_sizes=xy_sizes) + + # Collect optimization metadata. + diagnostics = _footprint_layout_diagnostics( + object_ids=object_ids, + centers=centers, + initial_centers=initial_centers, + xy_sizes=xy_sizes, + padding=padding, + initial_union_bounds=initial_union_bounds, + ) + metadata: dict[str, Any] = { + "method": "text_slsqp_then_mesh_aabb_collision_removal", + "relation_usage": "left_of_front_of_hard_constraints", + "padding": float(padding), + "padding_ratio": float(padding_ratio), + "weights": dict(_WEIGHTS), + "slsqp": slsqp_result, + "bounds_expansion": 2.0, + "initial_union_size": init_size.tolist(), + **collision_metadata, + "final_centers": { + oid: centers[oid].tolist() for oid in object_ids + }, + **diagnostics, + } + return {"centers": centers, "metadata": metadata} + + +# Build relation constraints. + + +def _build_relation_constraints( + *, + constraints: list[dict[str, Any]], + object_ids: list[str], + index_by_id: dict[str, int], + xy_sizes: dict[str, np.ndarray], + left_of_edges: list[tuple[str, str]], + front_of_edges: list[tuple[str, str]], + padding: float, +) -> None: + """Append SLSQP inequality constraints for left_of / front_of edges.""" + + for subject, obj in left_of_edges: + if subject not in index_by_id or obj not in index_by_id: + continue + i_a = index_by_id[subject] + i_b = index_by_id[obj] + # A.x + gap ≤ B.x → B.x - A.x - gap ≥ 0 + gap = ( + 0.5 * float(xy_sizes[subject][0]) + + 0.5 * float(xy_sizes[obj][0]) + + padding + ) + constraints.append( + { + "type": "ineq", + "fun": lambda x, ia=i_a, ib=i_b, g=gap: float( + x[2 * ib] - x[2 * ia] - g + ), + } + ) + + for subject, obj in front_of_edges: + if subject not in index_by_id or obj not in index_by_id: + continue + i_a = index_by_id[subject] + i_b = index_by_id[obj] + # A.y + gap ≤ B.y → B.y - A.y - gap ≥ 0 + gap = ( + 0.5 * float(xy_sizes[subject][1]) + + 0.5 * float(xy_sizes[obj][1]) + + padding + ) + constraints.append( + { + "type": "ineq", + "fun": lambda x, ia=i_a, ib=i_b, g=gap: float( + x[2 * ib + 1] - x[2 * ia + 1] - g + ), + } + ) + + +# Remove AABB collisions. + + +def _remove_mesh_aabb_collisions( + *, + object_ids: list[str], + xy_sizes: dict[str, np.ndarray], + centers: dict[str, np.ndarray], + initial_centers: dict[str, np.ndarray], + left_of_edges: list[tuple[str, str]], + front_of_edges: list[tuple[str, str]], + padding: float, +) -> tuple[dict[str, np.ndarray], dict[str, Any]]: + relation_pairs = set(left_of_edges + front_of_edges) + relation_pairs.update((b, a) for a, b in left_of_edges + front_of_edges) + current = { + oid: np.asarray(center, dtype=np.float64).copy() + for oid, center in centers.items() + } + max_rounds = 80 + total_pushes = 0 + last_overlap_count = 0 + + for iteration in range(max_rounds): + overlaps = _mesh_aabb_collision_pairs( + object_ids=object_ids, + xy_sizes=xy_sizes, + centers=current, + padding=padding, + ) + last_overlap_count = len(overlaps) + if not overlaps: + return current, { + "collision_iterations": iteration, + "collision_pushes": total_pushes, + "collision_remaining": 0, + "collision_removal": "iterative_mesh_aabb_push", + } + for item in overlaps: + object_a = item["object"] + object_b = item["other"] + axis = int(item["axis"]) + sign = -1.0 if current[object_a][axis] <= current[object_b][axis] else 1.0 + amount = 0.5 * (float(item["overlap"]) + 1.0e-6) + if (object_a, object_b) in relation_pairs: + current[object_a][axis] += sign * amount + current[object_b][axis] -= sign * amount + else: + drift_a = np.linalg.norm( + current[object_a] - initial_centers[object_a] + ) + drift_b = np.linalg.norm( + current[object_b] - initial_centers[object_b] + ) + if drift_a <= drift_b: + current[object_a][axis] += sign * amount * 1.25 + current[object_b][axis] -= sign * amount * 0.75 + else: + current[object_a][axis] += sign * amount * 0.75 + current[object_b][axis] -= sign * amount * 1.25 + total_pushes += 1 + current = _center_xy_aabb_layout(centers=current, xy_sizes=xy_sizes) + + return current, { + "collision_iterations": max_rounds, + "collision_pushes": total_pushes, + "collision_remaining": last_overlap_count, + "collision_removal": "iterative_mesh_aabb_push", + } + + +def _mesh_aabb_collision_pairs( + *, + object_ids: list[str], + xy_sizes: dict[str, np.ndarray], + centers: dict[str, np.ndarray], + padding: float, +) -> list[dict[str, Any]]: + pairs: list[dict[str, Any]] = [] + for i, oid in enumerate(object_ids): + for other_id in object_ids[i + 1 :]: + ov = _xy_aabb_overlap( + center_a=centers[oid], + size_a=xy_sizes[oid], + center_b=centers[other_id], + size_b=xy_sizes[other_id], + padding=padding, + ) + if ov is None: + continue + axis = 0 if ov[0] <= ov[1] else 1 + pairs.append( + { + "object": oid, + "other": other_id, + "axis": axis, + "overlap": float(ov[axis]), + "overlap_x": float(ov[0]), + "overlap_y": float(ov[1]), + } + ) + pairs.sort(key=lambda item: item["overlap"], reverse=True) + return pairs + + +# Pack and unpack center coordinates. + + +def _pack_centers( + object_ids: list[str], + centers: dict[str, np.ndarray], +) -> np.ndarray: + values: list[float] = [] + for oid in object_ids: + c = np.asarray(centers[oid], dtype=np.float64) + values.extend([float(c[0]), float(c[1])]) + return np.asarray(values, dtype=np.float64) + + +def _unpack_centers( + object_ids: list[str], + xvec: np.ndarray, +) -> dict[str, np.ndarray]: + return { + oid: np.asarray( + [xvec[2 * i], xvec[2 * i + 1]], + dtype=np.float64, + ) + for i, oid in enumerate(object_ids) + } diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/__init__.py deleted file mode 100644 index 8eca3510d..000000000 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - -from __future__ import annotations - -from embodichain.gen_sim.prompt2scene.agent_tools.managers.metric_scale_manager.manager import ( - METRIC_SCALE_ENABLED, - MetricScaleManager, -) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.metric_scale_manager.schemas import ( - EstimateMetricScalesRequest, - EstimateMetricScalesResult, - GlobalMetricScaleRequest, - MetricScaleObjectInput, -) - -__all__ = [ - "METRIC_SCALE_ENABLED", - "EstimateMetricScalesRequest", - "EstimateMetricScalesResult", - "GlobalMetricScaleRequest", - "MetricScaleManager", - "MetricScaleObjectInput", -] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/manager.py deleted file mode 100644 index ce1d47e9a..000000000 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/manager.py +++ /dev/null @@ -1,431 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - -from __future__ import annotations - -from typing import Any - -import numpy as np - -from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager import ( - GeometryManager, - LoadMeshRequest, -) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.metric_scale_manager.schemas import ( - EstimateMetricScalesRequest, - EstimateMetricScalesResult, - GlobalMetricScaleRequest, - MetricScaleObjectInput, -) -from embodichain.gen_sim.prompt2scene.utils.io import write_json -from embodichain.gen_sim.prompt2scene.workflows.llm_output import ( - call_structured_json_model_step, -) - -__all__ = ["METRIC_SCALE_ENABLED", "MetricScaleManager"] - -METRIC_SCALE_ENABLED = True - - -class MetricScaleManager: - """Manager for metric scale estimation and scale aggregation.""" - - @staticmethod - def estimate_metric_scales( - request: EstimateMetricScalesRequest, - ) -> EstimateMetricScalesResult: - """Call an LLM and convert bbox-size predictions into scale factors.""" - object_payload = MetricScaleManager.build_object_payload(request.objects) - raw_model_output_path = ( - request.raw_output_path.expanduser().resolve() - if request.raw_output_path is not None - else None - ) - raw_model_output = call_structured_json_model_step( - llm=request.llm, - schema=request.schema, - messages=request.messages, - context=request.context, - step_name=request.step_name, - output_root=None, - attempt_count=0, - raw_output_writer=( - (lambda payload: write_json(raw_model_output_path, payload)) - if raw_model_output_path is not None - else None - ), - ) - object_scales = MetricScaleManager.apply_model_output( - object_payload=object_payload, - raw_model_output=raw_model_output, - method=request.method, - ) - return EstimateMetricScalesResult( - status="ok", - object_scales=object_scales, - object_payload=object_payload, - raw_model_output=raw_model_output, - ) - - @staticmethod - def build_object_payload( - objects: list[MetricScaleObjectInput], - ) -> list[dict[str, Any]]: - """Build object payload with normalized mesh bbox measurements.""" - geom = GeometryManager() - payload: list[dict[str, Any]] = [] - for obj in objects: - mesh = geom.load_mesh(LoadMeshRequest(mesh_path=obj.mesh_path)).mesh - normalized_bbox_size_m = GeometryManager.mesh_aabb_size(mesh) - payload.append( - { - "object_id": obj.object_id, - "object_name": obj.object_name, - "object_description": obj.object_description, - "normalized_bbox_size_m": normalized_bbox_size_m.tolist(), - "normalized_bbox_ratio": GeometryManager.bbox_ratio( - normalized_bbox_size_m - ).tolist(), - } - ) - return payload - - @staticmethod - def object_prompt_payload( - objects: list[MetricScaleObjectInput], - ) -> list[dict[str, str]]: - """Return the lightweight object payload intended for LLM prompts.""" - return [ - { - "object_id": obj.object_id, - "object_name": obj.object_name, - "object_description": obj.object_description, - } - for obj in objects - ] - - @staticmethod - def apply_model_output( - *, - object_payload: list[dict[str, Any]], - raw_model_output: dict[str, Any], - method: str, - ) -> list[dict[str, Any]]: - """Convert model bbox predictions into per-object metric-scale records.""" - model_by_id = { - str(item.get("object_id", "")): item - for item in raw_model_output.get("object_scales", []) - if isinstance(item, dict) - } - estimates: list[dict[str, Any]] = [] - for payload in object_payload: - object_id = str(payload.get("object_id", "")) - model_item = model_by_id.get(object_id) - if model_item is None: - estimates.append( - MetricScaleManager.failure( - object_id=object_id, - reason="missing_object_scale_from_model", - method=method, - ) - ) - continue - estimates.append( - MetricScaleManager.select_candidate( - object_id=object_id, - object_name=str(payload.get("object_name", "")), - object_description=str(payload.get("object_description", "")), - bbox_dims_cm=model_item.get("bbox_dims_cm", []), - confidence=float(model_item.get("confidence", 0.0)), - reason=str(model_item.get("reason", "")), - normalized_bbox_size_m=np.asarray( - payload["normalized_bbox_size_m"], - dtype=np.float64, - ), - method=method, - ) - ) - return estimates - - @staticmethod - def apply_to_objects( - *, - objects: list[dict[str, Any]], - object_scales: list[dict[str, Any]], - ) -> None: - """Attach metric-scale records to object dictionaries by object id.""" - scale_by_id = {str(item.get("object_id", "")): item for item in object_scales} - for obj in objects: - object_id = str(obj.get("id", "")) - if object_id in scale_by_id: - obj["metric_scale"] = scale_by_id[object_id] - - @staticmethod - def select_candidate( - *, - object_id: str, - object_name: str, - object_description: str, - bbox_dims_cm: Any, - confidence: float, - reason: str, - normalized_bbox_size_m: np.ndarray, - method: str, - ) -> dict[str, Any]: - """Select a scale factor from predicted real-world bbox dimensions.""" - try: - selected = MetricScaleManager.compute_from_bbox_dims( - bbox_dims_cm=bbox_dims_cm, - confidence=confidence, - reason=reason, - normalized_bbox_size_m=normalized_bbox_size_m, - ) - except (TypeError, ValueError): - return MetricScaleManager.failure( - object_id=object_id, - reason="invalid_bbox_dims_cm", - method=method, - ) - normalized_bbox_size_cm = ( - np.asarray(normalized_bbox_size_m, dtype=np.float64) * 100.0 - ) - return { - "status": "ok", - "method": method, - "object_id": object_id, - "object_name": object_name, - "object_description": object_description, - "normalized_bbox_size_m": normalized_bbox_size_m.tolist(), - "normalized_bbox_size_cm": normalized_bbox_size_cm.tolist(), - "normalized_bbox_ratio": GeometryManager.bbox_ratio( - normalized_bbox_size_m - ).tolist(), - "bbox_dims_cm": selected["bbox_dims_cm"], - "axis_match": selected["axis_match"], - "scale_factor": selected["scale_factor"], - "confidence": selected["confidence"], - "reason": selected["reason"], - "unit_note": "scale_factor is not baked into this GLB.", - } - - @staticmethod - def compute_from_bbox_dims( - *, - bbox_dims_cm: Any, - confidence: float, - reason: str, - normalized_bbox_size_m: np.ndarray, - ) -> dict[str, Any]: - """Compute one scale candidate from model-predicted bbox dimensions.""" - dims_cm = np.asarray( - [float(value) for value in bbox_dims_cm], - dtype=np.float64, - ) - if dims_cm.shape != (3,) or np.any(dims_cm <= 0.0): - raise ValueError("bbox_dims_cm must contain three positive values.") - normalized_bbox_size_cm = ( - np.asarray(normalized_bbox_size_m, dtype=np.float64) * 100.0 - ) - axis_match = GeometryManager.best_axis_bbox_scale_match( - source_size_cm=normalized_bbox_size_cm, - target_size_cm=dims_cm, - ) - return { - "bbox_dims_cm": dims_cm.tolist(), - "axis_match": axis_match, - "scale_factor": float(axis_match["scale_factor"]), - "confidence": confidence, - "reason": reason, - } - - @staticmethod - def failure( - *, - object_id: str, - reason: str, - method: str, - ) -> dict[str, Any]: - """Build a failed per-object metric-scale record.""" - return { - "status": "failed", - "method": method, - "object_id": object_id, - "scale_factor": 1.0, - "reason": reason, - } - - @staticmethod - def set_for_all_objects( - *, - objects: list[dict[str, Any]], - status: str, - reason: str, - method: str, - ) -> None: - """Attach the same fallback metric-scale status to all objects.""" - for obj in objects: - obj["metric_scale"] = { - "status": status, - "method": method, - "object_id": str(obj.get("id", "")), - "scale_factor": 1.0, - "reason": reason, - } - - @staticmethod - def compute_global_from_object_scenes( - request: GlobalMetricScaleRequest, - ) -> dict[str, Any]: - """Aggregate object metric scales into one global scale for a scene layout.""" - if not METRIC_SCALE_ENABLED: - return { - "status": "disabled", - "method": "metric_scale_disabled", - "scale_factor": 1.0, - "object_count": len(request.objects), - "used_count": 0, - "skipped_count": len(request.objects), - "used": [], - "skipped": [ - {"id": str(item.get("id", "")), "reason": "metric_scale_disabled"} - for item in request.objects - ], - "unit_note": ( - "Metric scale is disabled; aligned GLBs keep simready " - "normalized size." - ), - } - - used: list[dict[str, Any]] = [] - skipped: list[dict[str, Any]] = [] - object_by_id = {str(item.get("id", "")): item for item in request.objects} - for object_id, scene in request.object_scenes: - item = object_by_id.get(object_id) - if item is None: - skipped.append({"id": object_id, "reason": "missing_object_record"}) - continue - metric_scale = item.get("metric_scale") - if not isinstance(metric_scale, dict): - skipped.append({"id": object_id, "reason": "missing_metric_scale"}) - continue - if metric_scale.get("status") != "ok": - skipped.append( - { - "id": object_id, - "reason": str(metric_scale.get("status") or "not_ok"), - } - ) - continue - - scale_factor_simready = float(metric_scale.get("scale_factor", 1.0)) - if not np.isfinite(scale_factor_simready) or scale_factor_simready <= 0.0: - skipped.append( - {"id": object_id, "reason": "invalid_simready_scale_factor"} - ) - continue - try: - simready_size_m = np.asarray( - [float(v) for v in metric_scale.get("normalized_bbox_size_m", [])], - dtype=np.float64, - ) - except (TypeError, ValueError): - skipped.append( - {"id": object_id, "reason": "invalid_normalized_bbox_size_m"} - ) - continue - if simready_size_m.shape != (3,) or np.any(simready_size_m <= 0.0): - skipped.append( - {"id": object_id, "reason": "invalid_normalized_bbox_size_m"} - ) - continue - - current_bounds = np.asarray(GeometryManager.scene_to_mesh(scene).bounds) - current_size_m = current_bounds[1] - current_bounds[0] - if current_size_m.shape != (3,) or np.any(current_size_m <= 0.0): - skipped.append({"id": object_id, "reason": "invalid_current_scene_aabb"}) - continue - - geo_ratio = np.sort(current_size_m) / np.sort(simready_size_m) - geo_scale = float(np.median(geo_ratio)) - if not np.isfinite(geo_scale) or geo_scale <= 0.0: - skipped.append({"id": object_id, "reason": "non_positive_geo_scale"}) - continue - - effective_scale = scale_factor_simready / geo_scale - if not np.isfinite(effective_scale) or effective_scale <= 0.0: - skipped.append( - {"id": object_id, "reason": "non_positive_effective_scale"} - ) - continue - - used.append( - { - "id": object_id, - "effective_scale": effective_scale, - "scale_factor_simready": scale_factor_simready, - "geo_scale": geo_scale, - "simready_bbox_size_m": simready_size_m.tolist(), - "simready_bbox_size_cm": (simready_size_m * 100.0).tolist(), - "current_scene_bbox_size_m": current_size_m.tolist(), - "current_scene_bbox_size_cm": (current_size_m * 100.0).tolist(), - "target_bbox_dims_cm": metric_scale.get("bbox_dims_cm"), - "confidence": metric_scale.get("confidence"), - } - ) - - if not used: - return { - "status": "fallback", - "method": "simready_reference_geo_ratio_mean_with_clamp", - "scale_factor": 1.0, - "raw_scale_factor": 1.0, - "was_clamped": False, - "clamp": {"min": request.min_scale, "max": request.max_scale}, - "object_count": len(request.objects), - "used_count": 0, - "skipped_count": len(skipped), - "used": [], - "skipped": skipped, - "unit_note": ( - "No valid metric scale was available; image clutter keeps the " - "SAM3D layout scale without an additional metric scale." - ), - } - - raw_scale_factor = float(np.mean([item["effective_scale"] for item in used])) - scale_factor = float( - np.clip(raw_scale_factor, request.min_scale, request.max_scale) - ) - return { - "status": "ok", - "method": "simready_reference_geo_ratio_mean_with_clamp", - "scale_factor": scale_factor, - "raw_scale_factor": raw_scale_factor, - "was_clamped": bool(scale_factor != raw_scale_factor), - "clamp": {"min": request.min_scale, "max": request.max_scale}, - "object_count": len(request.objects), - "used_count": len(used), - "skipped_count": len(skipped), - "used": used, - "skipped": skipped, - "unit_note": ( - "Global scale derived from scene-level VLM per-object scale_factor " - "divided by the geometric scale ratio between simready normalized " - "bbox and current aligned scene bbox (sorted, permutation-invariant). " - f"Aggregated via mean across objects, clamped to " - f"[{request.min_scale:.2f}, {request.max_scale:.2f}]." - ), - } diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/schemas.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/schemas.py deleted file mode 100644 index dd2de3437..000000000 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/metric_scale_manager/schemas.py +++ /dev/null @@ -1,73 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - -from __future__ import annotations - -from dataclasses import dataclass -from pathlib import Path -from typing import Any - -__all__ = [ - "EstimateMetricScalesRequest", - "EstimateMetricScalesResult", - "GlobalMetricScaleRequest", - "MetricScaleObjectInput", -] - - -@dataclass(frozen=True) -class MetricScaleObjectInput: - """Object input for metric-scale estimation.""" - - object_id: str - object_name: str - object_description: str - mesh_path: Path - - -@dataclass(frozen=True) -class EstimateMetricScalesRequest: - """Request to estimate metric scale for a set of normalized objects.""" - - objects: list[MetricScaleObjectInput] - messages: list[dict[str, Any]] - schema: dict[str, Any] - llm: Any - context: str - method: str - step_name: str = "metric_scale" - raw_output_path: Path | None = None - - -@dataclass(frozen=True) -class EstimateMetricScalesResult: - """Result of estimating metric scale for normalized objects.""" - - status: str - object_scales: list[dict[str, Any]] - object_payload: list[dict[str, Any]] - raw_model_output: dict[str, Any] | None = None - reason: str = "" - - -@dataclass(frozen=True) -class GlobalMetricScaleRequest: - """Request to aggregate per-object metric scales into one scene scale.""" - - objects: list[dict[str, Any]] - object_scenes: list[tuple[str, Any]] - min_scale: float = 0.10 - max_scale: float = 10.00 diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/__init__.py deleted file mode 100644 index b61756bf0..000000000 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - -from __future__ import annotations - -from embodichain.gen_sim.prompt2scene.agent_tools.managers.optimization_manager.manager import ( - _center_xy_aabb_layout, - _footprint_layout_diagnostics, - _object_scenes_xy_aabb_manifest, - _settle_and_pack_object_footprints, - _xy_aabb_overlap, - _xy_union_area, - _xy_union_bounds, -) - -__all__ = [ - "_center_xy_aabb_layout", - "_footprint_layout_diagnostics", - "_object_scenes_xy_aabb_manifest", - "_settle_and_pack_object_footprints", - "_xy_aabb_overlap", - "_xy_union_area", - "_xy_union_bounds", -] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/manager.py deleted file mode 100644 index d7ed13484..000000000 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/optimization_manager/manager.py +++ /dev/null @@ -1,633 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - - -from __future__ import annotations - -import tempfile -import traceback -from pathlib import Path -from typing import Any - -import numpy as np - -from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager import ( - SimulationManager, -) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager.schemas import ( - GravityDropRequest, -) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.scene_geometry import ( - _aabb_bottom_to_xy_plane_transform, - _copy_scene_with_transform, - _matrix_from_json, - _scene_to_mesh, - _xy_aabb_center, - _xy_aabb_size, - _z_up_to_glb_y_up_transform, -) -from embodichain.gen_sim.prompt2scene.utils.io import ( - relative_path, -) - -__all__ = [ - "_center_xy_aabb_layout", - "_object_scenes_xy_aabb_manifest", - "_settle_and_pack_object_footprints", - "_xy_aabb_overlap", - "_xy_union_area", - "_xy_union_bounds", -] - -def _object_scenes_xy_aabb_manifest( - *, - object_scenes: list[tuple[str, Any]], - trimesh: Any, - unit_scale: float, - unit: str, -) -> dict[str, Any]: - if not object_scenes: - return { - "status": "empty", - "unit": unit, - "object_count": 0, - } - bounds = [ - np.asarray(_scene_to_mesh(scene, trimesh=trimesh).bounds, dtype=np.float64) - for _, scene in object_scenes - ] - union_bounds = np.vstack( - [ - np.vstack([item[0] for item in bounds]).min(axis=0), - np.vstack([item[1] for item in bounds]).max(axis=0), - ] - ) - min_xy = union_bounds[0, :2] * unit_scale - max_xy = union_bounds[1, :2] * unit_scale - size_xy = max_xy - min_xy - center_xy = 0.5 * (min_xy + max_xy) - return { - "status": "ok", - "unit": unit, - "object_count": len(object_scenes), - "min_xy": min_xy.tolist(), - "max_xy": max_xy.tolist(), - "center_xy": center_xy.tolist(), - "size_xy": size_xy.tolist(), - "area": float(size_xy[0] * size_xy[1]), - } - - - -def _settle_and_pack_object_footprints( - *, - object_scenes: list[tuple[str, Any]], - output_dir: Path, - output_root: Path, - trimesh: Any, -) -> dict[str, Any]: - sim = SimulationManager(headless=True, sim_device="cpu") - footprint_items: list[dict[str, Any]] = [] - settled_entries: list[dict[str, Any]] = [] - output_axis_transform = _z_up_to_glb_y_up_transform() - output_to_internal_transform = np.linalg.inv(output_axis_transform) - - with tempfile.TemporaryDirectory(prefix="p2s_footprint_drop_") as tmp_dir: - tmp_path = Path(tmp_dir) - for object_id, scene in object_scenes: - mesh = _scene_to_mesh(scene, trimesh=trimesh) - mesh_bounds = np.asarray(mesh.bounds, dtype=np.float64) - mesh_z_height = max(float(mesh_bounds[1][2] - mesh_bounds[0][2]), 0.0) - bottom_to_xy_plane_transform = _aabb_bottom_to_xy_plane_transform( - mesh_bounds - ) - normalized_scene = _copy_scene_with_transform( - scene, - bottom_to_xy_plane_transform, - ) - normalized_output_scene = _copy_scene_with_transform( - normalized_scene, - output_axis_transform, - ) - pre_gravity_path = tmp_path / f"{object_id}_pre_gravity.glb" - normalized_output_scene.export(pre_gravity_path) - gravity_initial_height = mesh_z_height * 0.1 - - gravity_status = "ok" - gravity_transform = np.eye(4, dtype=np.float64) - gravity_reason = "" - try: - gravity_result = sim.run_gravity_simulation( - GravityDropRequest( - glb_path=pre_gravity_path, - max_convex_hull_num=32, - initial_height=gravity_initial_height, - ) - ) - gravity_transform = _matrix_from_json( - gravity_result.final_pose, - name=f"{object_id}.gravity_final_pose", - ) - except Exception: - gravity_status = "failed" - gravity_reason = traceback.format_exc() - - settled_origin_scene = _copy_scene_with_transform( - normalized_scene, - gravity_transform, - ) - settled_mesh = _scene_to_mesh(settled_origin_scene, trimesh=trimesh) - settled_bounds = np.asarray(settled_mesh.bounds, dtype=np.float64) - settled_xy_center = _xy_aabb_center(settled_bounds) - settled_xy_size = _xy_aabb_size(settled_bounds) - settled_entries.append( - { - "id": object_id, - "scene": scene, - "bottom_to_xy_plane_transform": bottom_to_xy_plane_transform, - "mesh_z_height": mesh_z_height, - "gravity_initial_height": gravity_initial_height, - "gravity_transform": gravity_transform, - "settled_bounds": settled_bounds, - "settled_xy_center": settled_xy_center, - "settled_xy_size": settled_xy_size, - "gravity_status": gravity_status, - "gravity_reason": gravity_reason, - } - ) - - layout_result = _optimize_xy_aabb_footprint_layout( - object_ids=[str(entry["id"]) for entry in settled_entries], - xy_sizes={ - str(entry["id"]): np.asarray(entry["settled_xy_size"], dtype=np.float64) - for entry in settled_entries - }, - current_centers={ - str(entry["id"]): _xy_aabb_center( - _scene_to_mesh(entry["scene"], trimesh=trimesh).bounds - ) - for entry in settled_entries - }, - ) - target_centers = layout_result["centers"] - - packed_object_scenes: list[tuple[str, Any]] = [] - object_layout_transforms: dict[str, np.ndarray] = {} - for entry in settled_entries: - object_id = str(entry["id"]) - settled_bounds = np.asarray(entry["settled_bounds"], dtype=np.float64) - target_xy = target_centers[object_id] - placement_transform = np.eye(4, dtype=np.float64) - placement_transform[:3, 3] = [ - float(target_xy[0] - entry["settled_xy_center"][0]), - float(target_xy[1] - entry["settled_xy_center"][1]), - -float(settled_bounds[0][2]), - ] - object_transform = ( - placement_transform - @ entry["gravity_transform"] - @ entry["bottom_to_xy_plane_transform"] - ) - packed_scene = _copy_scene_with_transform(entry["scene"], object_transform) - packed_object_scenes.append((object_id, packed_scene)) - object_layout_transforms[object_id] = object_transform - - packed_bounds = np.asarray( - _scene_to_mesh(packed_scene, trimesh=trimesh).bounds, - dtype=np.float64, - ) - footprint_items.append( - { - "id": object_id, - "gravity_status": entry["gravity_status"], - "gravity_reason": entry["gravity_reason"], - "bottom_to_xy_plane_transform": entry[ - "bottom_to_xy_plane_transform" - ].tolist(), - "mesh_z_height": entry["mesh_z_height"], - "gravity_initial_height": entry["gravity_initial_height"], - "gravity_transform": entry["gravity_transform"].tolist(), - "placement_transform": placement_transform.tolist(), - "object_layout_transform": object_transform.tolist(), - "settled_xy_size": entry["settled_xy_size"].tolist(), - "target_xy_center": target_xy.tolist(), - "packed_bounds": packed_bounds.tolist(), - } - ) - - manifest = { - "status": "ok", - "method": "per_object_gravity_then_geometry_knn_2d_aabb_relaxation", - "output_dir": relative_path(str(output_dir), output_root), - "internal_up_axis": [0.0, 0.0, 1.0], - "gravity_glb_up_axis": [0.0, 1.0, 0.0], - "internal_to_gravity_glb_transform": output_axis_transform.tolist(), - "gravity_glb_to_internal_transform": output_to_internal_transform.tolist(), - "layout_optimization": layout_result["metadata"], - "items": footprint_items, - } - return { - "object_scenes": packed_object_scenes, - "object_layout_transforms": object_layout_transforms, - "manifest": manifest, - } - - - -def _optimize_xy_aabb_footprint_layout( - *, - object_ids: list[str], - xy_sizes: dict[str, np.ndarray], - current_centers: dict[str, np.ndarray], - padding_ratio: float = 0.08, -) -> dict[str, Any]: - if not object_ids: - return { - "centers": {}, - "metadata": { - "method": "geometry_knn_2d_aabb_relaxation", - "iterations": 0, - "confidence_score": 1.0, - }, - } - - max_extent = max( - float(max(xy_sizes[object_id][0], xy_sizes[object_id][1])) - for object_id in object_ids - ) - padding = max(max_extent * padding_ratio, 1e-3) - max_iterations = 300 - overlap_strength = 1.0 - neighbor_strength = 0.04 - compactness_strength = 0.01 - target_expansion_ratio = 1.2 - knn_k = min(3, max(len(object_ids) - 1, 0)) - centers = { - object_id: np.asarray( - current_centers.get(object_id, np.zeros(2, dtype=np.float64)), - dtype=np.float64, - ).copy() - for object_id in object_ids - } - centers = _center_xy_aabb_layout( - centers=centers, - xy_sizes=xy_sizes, - ) - initial_centers = { - object_id: center.copy() - for object_id, center in centers.items() - } - initial_union_bounds = _xy_union_bounds( - centers=initial_centers, - xy_sizes=xy_sizes, - ) - neighbor_edges = _knn_neighbor_edges( - centers=initial_centers, - k=knn_k, - ) - - iterations = 0 - for iteration in range(max_iterations): - iterations = iteration + 1 - max_delta = 0.0 - - for i, object_id in enumerate(object_ids): - for other_id in object_ids[i + 1 :]: - overlap = _xy_aabb_overlap( - center_a=centers[object_id], - size_a=xy_sizes[object_id], - center_b=centers[other_id], - size_b=xy_sizes[other_id], - padding=padding, - ) - if overlap is None: - continue - overlap_x, overlap_y = overlap - if overlap_x <= overlap_y: - axis = 0 - sign = ( - -1.0 - if centers[object_id][0] <= centers[other_id][0] - else 1.0 - ) - amount = overlap_x - else: - axis = 1 - sign = ( - -1.0 - if centers[object_id][1] <= centers[other_id][1] - else 1.0 - ) - amount = overlap_y - shift = 0.5 * (amount + 1e-6) * overlap_strength - centers[object_id][axis] += sign * shift - centers[other_id][axis] -= sign * shift - max_delta = max(max_delta, shift) - - for edge in neighbor_edges: - object_id = edge["object"] - neighbor_id = edge["neighbor"] - initial_delta = np.asarray(edge["initial_delta"], dtype=np.float64) - error = (centers[object_id] - centers[neighbor_id]) - initial_delta - correction = 0.5 * neighbor_strength * error - centers[object_id] -= correction - centers[neighbor_id] += correction - max_delta = max(max_delta, float(np.linalg.norm(correction))) - - max_delta = max( - max_delta, - _apply_compactness_pull( - centers=centers, - xy_sizes=xy_sizes, - initial_union_bounds=initial_union_bounds, - target_expansion_ratio=target_expansion_ratio, - strength=compactness_strength, - ), - ) - - centers = _center_xy_aabb_layout( - centers=centers, - xy_sizes=xy_sizes, - ) - if iteration >= 20 and max_delta < 1e-5: - break - - diagnostics = _footprint_layout_diagnostics( - object_ids=object_ids, - centers=centers, - initial_centers=initial_centers, - xy_sizes=xy_sizes, - padding=padding, - initial_union_bounds=initial_union_bounds, - ) - metadata = { - "method": "geometry_knn_2d_aabb_relaxation", - "relation_usage": "disabled", - "iterations": iterations, - "padding": padding, - "padding_ratio": padding_ratio, - "max_iterations": max_iterations, - "overlap_strength": overlap_strength, - "neighbor_strength": neighbor_strength, - "compactness_strength": compactness_strength, - "target_expansion_ratio": target_expansion_ratio, - "knn_k": knn_k, - "neighbor_edges": neighbor_edges, - "final_centers": { - object_id: centers[object_id].tolist() - for object_id in object_ids - }, - **diagnostics, - } - return {"centers": centers, "metadata": metadata} - - - -def _knn_neighbor_edges( - *, - centers: dict[str, np.ndarray], - k: int, -) -> list[dict[str, Any]]: - if k <= 0 or len(centers) < 2: - return [] - object_ids = sorted(centers) - edges: list[dict[str, Any]] = [] - seen: set[tuple[str, str]] = set() - for object_id in object_ids: - distances = [] - for other_id in object_ids: - if other_id == object_id: - continue - distance = float(np.linalg.norm(centers[object_id] - centers[other_id])) - distances.append((distance, other_id)) - for _, neighbor_id in sorted(distances)[:k]: - edge_key = tuple(sorted((object_id, neighbor_id))) - if edge_key in seen: - continue - seen.add(edge_key) - edges.append( - { - "object": object_id, - "neighbor": neighbor_id, - "initial_delta": ( - centers[object_id] - centers[neighbor_id] - ).tolist(), - } - ) - return edges - - - -def _apply_compactness_pull( - *, - centers: dict[str, np.ndarray], - xy_sizes: dict[str, np.ndarray], - initial_union_bounds: np.ndarray, - target_expansion_ratio: float, - strength: float, -) -> float: - current_bounds = _xy_union_bounds(centers=centers, xy_sizes=xy_sizes) - expansion_ratio = _xy_union_area(current_bounds) / max( - _xy_union_area(initial_union_bounds), - 1.0e-12, - ) - if expansion_ratio <= target_expansion_ratio: - return 0.0 - excess = min(expansion_ratio / target_expansion_ratio - 1.0, 1.0) - union_center = 0.5 * (current_bounds[0] + current_bounds[1]) - factor = strength * excess - max_delta = 0.0 - for object_id, center in centers.items(): - delta = factor * (union_center - center) - centers[object_id] = center + delta - max_delta = max(max_delta, float(np.linalg.norm(delta))) - return max_delta - - - -def _footprint_layout_diagnostics( - *, - object_ids: list[str], - centers: dict[str, np.ndarray], - initial_centers: dict[str, np.ndarray], - xy_sizes: dict[str, np.ndarray], - padding: float, - initial_union_bounds: np.ndarray, -) -> dict[str, Any]: - remaining_overlaps = _remaining_xy_overlaps( - object_ids=object_ids, - centers=centers, - xy_sizes=xy_sizes, - padding=padding, - ) - displacements = [ - float(np.linalg.norm(centers[object_id] - initial_centers[object_id])) - for object_id in object_ids - ] - current_union_bounds = _xy_union_bounds(centers=centers, xy_sizes=xy_sizes) - expansion_ratio = _xy_union_area(current_union_bounds) / max( - _xy_union_area(initial_union_bounds), - 1.0e-12, - ) - average_displacement = float(np.mean(displacements)) if displacements else 0.0 - max_displacement = float(np.max(displacements)) if displacements else 0.0 - confidence_score = _footprint_confidence_score( - remaining_overlap_count=len(remaining_overlaps), - average_displacement=average_displacement, - max_extent=max( - float(max(xy_sizes[object_id][0], xy_sizes[object_id][1])) - for object_id in object_ids - ) - if object_ids - else 1.0, - expansion_ratio=expansion_ratio, - ) - return { - "remaining_overlaps": remaining_overlaps, - "average_displacement": average_displacement, - "max_displacement": max_displacement, - "union_aabb_expansion_ratio": expansion_ratio, - "confidence_score": confidence_score, - } - - - -def _remaining_xy_overlaps( - *, - object_ids: list[str], - centers: dict[str, np.ndarray], - xy_sizes: dict[str, np.ndarray], - padding: float, -) -> list[dict[str, Any]]: - overlaps: list[dict[str, Any]] = [] - for index, object_id in enumerate(object_ids): - for other_id in object_ids[index + 1 :]: - overlap = _xy_aabb_overlap( - center_a=centers[object_id], - size_a=xy_sizes[object_id], - center_b=centers[other_id], - size_b=xy_sizes[other_id], - padding=padding, - ) - if overlap is None: - continue - overlaps.append( - { - "object": object_id, - "other": other_id, - "overlap_x": overlap[0], - "overlap_y": overlap[1], - } - ) - return overlaps - - - -def _footprint_confidence_score( - *, - remaining_overlap_count: int, - average_displacement: float, - max_extent: float, - expansion_ratio: float, -) -> float: - displacement_scale = max(max_extent, 1.0e-6) - overlap_penalty = min(0.35 * remaining_overlap_count, 0.7) - displacement_penalty = min(0.1 * average_displacement / displacement_scale, 0.2) - expansion_penalty = min(max(expansion_ratio - 1.2, 0.0) * 0.25, 0.2) - return float( - np.clip( - 1.0 - - overlap_penalty - - displacement_penalty - - expansion_penalty, - 0.0, - 1.0, - ) - ) - - - -def _center_xy_aabb_layout( - *, - centers: dict[str, np.ndarray], - xy_sizes: dict[str, np.ndarray], -) -> dict[str, np.ndarray]: - if not centers: - return centers - bounds_min = [] - bounds_max = [] - for object_id, center in centers.items(): - half_size = 0.5 * np.asarray(xy_sizes[object_id], dtype=np.float64) - bounds_min.append(center - half_size) - bounds_max.append(center + half_size) - clutter_center = 0.5 * ( - np.vstack(bounds_min).min(axis=0) - + np.vstack(bounds_max).max(axis=0) - ) - return { - object_id: np.asarray(center, dtype=np.float64) - clutter_center - for object_id, center in centers.items() - } - - - -def _xy_union_bounds( - *, - centers: dict[str, np.ndarray], - xy_sizes: dict[str, np.ndarray], -) -> np.ndarray: - if not centers: - return np.zeros((2, 2), dtype=np.float64) - bounds_min = [] - bounds_max = [] - for object_id, center in centers.items(): - half_size = 0.5 * np.asarray(xy_sizes[object_id], dtype=np.float64) - bounds_min.append(np.asarray(center, dtype=np.float64) - half_size) - bounds_max.append(np.asarray(center, dtype=np.float64) + half_size) - return np.vstack( - [ - np.vstack(bounds_min).min(axis=0), - np.vstack(bounds_max).max(axis=0), - ] - ) - - - -def _xy_union_area(bounds: np.ndarray) -> float: - bounds = np.asarray(bounds, dtype=np.float64) - size = np.maximum(bounds[1] - bounds[0], 1.0e-9) - return float(size[0] * size[1]) - - - -def _xy_aabb_overlap( - *, - center_a: np.ndarray, - size_a: np.ndarray, - center_b: np.ndarray, - size_b: np.ndarray, - padding: float, -) -> tuple[float, float] | None: - half_a = 0.5 * np.asarray(size_a, dtype=np.float64) - half_b = 0.5 * np.asarray(size_b, dtype=np.float64) - delta = np.abs( - np.asarray(center_b, dtype=np.float64) - - np.asarray(center_a, dtype=np.float64) - ) - overlap = half_a + half_b + padding - delta - if float(overlap[0]) <= 0.0 or float(overlap[1]) <= 0.0: - return None - return float(overlap[0]), float(overlap[1]) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/__init__.py index 12ebfd690..b85c8749f 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/__init__.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/__init__.py @@ -17,19 +17,40 @@ from __future__ import annotations from embodichain.gen_sim.prompt2scene.agent_tools.managers.simready_manager.manager import ( + METRIC_SCALE_ENABLED, SimreadyManager, ) from embodichain.gen_sim.prompt2scene.agent_tools.managers.simready_manager.schemas import ( + EstimateMetricScalesRequest, + EstimateMetricScalesResult, + GlobalMetricScaleRequest, MakeAssetSimreadyRequest, MakeAssetSimreadyResult, MakeTableSimreadyRequest, MakeTableSimreadyResult, + MetricScaleObjectInput, ) +from embodichain.gen_sim.prompt2scene.prompts.builders import ( + build_image_metric_scale_messages, +) +from embodichain.gen_sim.prompt2scene.prompts.schemas import ( + IMAGE_METRIC_SCALE_JSON_SCHEMA, +) + +MetricScaleManager = SimreadyManager __all__ = [ + "EstimateMetricScalesRequest", + "EstimateMetricScalesResult", + "GlobalMetricScaleRequest", + "IMAGE_METRIC_SCALE_JSON_SCHEMA", "MakeAssetSimreadyRequest", "MakeAssetSimreadyResult", "MakeTableSimreadyRequest", "MakeTableSimreadyResult", + "METRIC_SCALE_ENABLED", + "MetricScaleManager", + "MetricScaleObjectInput", "SimreadyManager", + "build_image_metric_scale_messages", ] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/manager.py index 6f92e1f84..ae46a552f 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/manager.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/manager.py @@ -55,6 +55,23 @@ ) +METRIC_SCALE_ENABLED = True + +from .utils import ( + _as_transform, + _axis_angle_rotation, + _axis_conversion_transform, + _center_aabb_bottom_xy_at_origin, + _center_aabb_bottom_xy_at_origin_transform, + _normalize, + _orthogonal_axis, + _place_above_plane_transform, + _request_axis, + _rotation_between_vectors, + _scale_transform, + _translation_transform, +) + class SimreadyManager: """Prepare generated GLB assets for simulation placement.""" @@ -294,103 +311,334 @@ def make_table_simready( output_path=output_path, transform_matrix=raw_to_simready.tolist(), ) + @staticmethod + def estimate_metric_scales(request): + from embodichain.gen_sim.prompt2scene.agent_tools.managers.simready_manager.schemas import ( + EstimateMetricScalesRequest, + EstimateMetricScalesResult, + ) + from embodichain.gen_sim.prompt2scene.llms.llm_output import ( + call_structured_json_model_step, + ) + from embodichain.gen_sim.prompt2scene.utils.io import write_json + + object_payload = SimreadyManager.build_object_payload(request.objects) + raw_model_output_path = ( + request.raw_output_path.expanduser().resolve() + if request.raw_output_path is not None + else None + ) + raw_model_output = call_structured_json_model_step( + llm=request.llm, + schema=request.schema, + messages=request.messages, + context=request.context, + attempt_count=0, + raw_output_writer=( + (lambda payload: write_json(raw_model_output_path, payload)) + if raw_model_output_path is not None + else None + ), + ) + object_scales = SimreadyManager.apply_model_output( + object_payload=object_payload, + raw_model_output=raw_model_output, + method=request.method, + ) + return EstimateMetricScalesResult( + status="ok", + object_scales=object_scales, + object_payload=object_payload, + raw_model_output=raw_model_output, + ) + + + @staticmethod + def build_object_payload(objects): + from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager import ( + GeometryManager, + LoadMeshRequest, + ) + from embodichain.gen_sim.prompt2scene.agent_tools.managers.simready_manager.schemas import ( + MetricScaleObjectInput, + ) + + geom = GeometryManager() + payload = [] + for obj in objects: + mesh = geom.load_mesh(LoadMeshRequest(mesh_path=obj.mesh_path)).mesh + normalized_bbox_size_m = GeometryManager.mesh_aabb_size(mesh) + payload.append({ + "object_id": obj.object_id, + "object_name": obj.object_name, + "object_description": obj.object_description, + "normalized_bbox_size_m": normalized_bbox_size_m.tolist(), + "normalized_bbox_ratio": GeometryManager.bbox_ratio( + normalized_bbox_size_m + ).tolist(), + }) + return payload + + + @staticmethod + def object_prompt_payload(objects): + return [ + { + "object_id": obj.object_id, + "object_name": obj.object_name, + "object_description": obj.object_description, + } + for obj in objects + ] + + + @staticmethod + def apply_model_output(*, object_payload, raw_model_output, method): + import numpy as np + + from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager import ( + GeometryManager, + ) + + model_by_id = { + str(item.get("object_id", "")): item + for item in raw_model_output.get("object_scales", []) + if isinstance(item, dict) + } + estimates = [] + for p in object_payload: + oid = str(p.get("object_id", "")) + model_item = model_by_id.get(oid) + if model_item is None: + estimates.append(SimreadyManager.failure( + object_id=oid, + reason="missing_object_scale_from_model", + method=method, + )) + continue + estimates.append(SimreadyManager.select_candidate( + object_id=oid, + object_name=str(p.get("object_name", "")), + object_description=str(p.get("object_description", "")), + bbox_dims_cm=model_item.get("bbox_dims_cm", []), + confidence=float(model_item.get("confidence", 0.0)), + reason=str(model_item.get("reason", "")), + normalized_bbox_size_m=np.asarray( + p["normalized_bbox_size_m"], dtype=np.float64 + ), + method=method, + )) + return estimates + + + @staticmethod + def apply_to_objects(*, objects, object_scales): + scale_by_id = {str(item.get("object_id", "")): item for item in object_scales} + for obj in objects: + oid = str(obj.get("id", "")) + if oid in scale_by_id: + obj["metric_scale"] = scale_by_id[oid] + + + @staticmethod + def select_candidate(*, object_id, object_name, object_description, bbox_dims_cm, confidence, reason, normalized_bbox_size_m, method): + import numpy as np + from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager import ( + GeometryManager, + ) + try: + selected = SimreadyManager.compute_from_bbox_dims( + bbox_dims_cm=bbox_dims_cm, + confidence=confidence, + reason=reason, + normalized_bbox_size_m=normalized_bbox_size_m, + ) + except (TypeError, ValueError): + return SimreadyManager.failure( + object_id=object_id, + reason="invalid_bbox_dims_cm", + method=method, + ) + nbs_cm = np.asarray(normalized_bbox_size_m, dtype=np.float64) * 100.0 + return { + "status": "ok", + "method": method, + "object_id": object_id, + "object_name": object_name, + "object_description": object_description, + "normalized_bbox_size_m": normalized_bbox_size_m.tolist(), + "normalized_bbox_size_cm": nbs_cm.tolist(), + "normalized_bbox_ratio": GeometryManager.bbox_ratio( + normalized_bbox_size_m + ).tolist(), + "bbox_dims_cm": selected["bbox_dims_cm"], + "axis_match": selected["axis_match"], + "scale_factor": selected["scale_factor"], + "confidence": selected["confidence"], + "reason": selected["reason"], + "unit_note": "scale_factor is not baked into this GLB.", + } + + + @staticmethod + def compute_from_bbox_dims(*, bbox_dims_cm, confidence, reason, normalized_bbox_size_m): + import numpy as np + from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager import ( + GeometryManager, + ) + dims_cm = np.asarray([float(v) for v in bbox_dims_cm], dtype=np.float64) + if dims_cm.shape != (3,) or np.any(dims_cm <= 0.0): + raise ValueError("bbox_dims_cm must contain three positive values.") + nbs_cm = np.asarray(normalized_bbox_size_m, dtype=np.float64) * 100.0 + axis_match = GeometryManager.best_axis_bbox_scale_match( + source_size_cm=nbs_cm, + target_size_cm=dims_cm, + ) + return { + "bbox_dims_cm": dims_cm.tolist(), + "axis_match": axis_match, + "scale_factor": float(axis_match["scale_factor"]), + "confidence": confidence, + "reason": reason, + } + + + @staticmethod + def failure(*, object_id, reason, method): + return { + "status": "failed", + "method": method, + "object_id": object_id, + "scale_factor": 1.0, + "reason": reason, + } + + + @staticmethod + def set_for_all_objects(*, objects, status, reason, method): + for obj in objects: + obj["metric_scale"] = { + "status": status, + "method": method, + "object_id": str(obj.get("id", "")), + "scale_factor": 1.0, + "reason": reason, + } + + + @staticmethod + def compute_global_from_object_scenes(request): + import numpy as np + from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager import ( + GeometryManager, + ) + from embodichain.gen_sim.prompt2scene.agent_tools.managers.simready_manager.schemas import ( + GlobalMetricScaleRequest, + ) + if not METRIC_SCALE_ENABLED: + return { + "status": "disabled", + "method": "metric_scale_disabled", + "scale_factor": 1.0, + "object_count": len(request.objects), + "used_count": 0, + "skipped_count": len(request.objects), + "used": [], + "skipped": [ + {"id": str(item.get("id", "")), "reason": "metric_scale_disabled"} + for item in request.objects + ], + "unit_note": "Metric scale is disabled; GLBs keep simready size.", + } + + used = [] + skipped = [] + object_by_id = {str(item.get("id", "")): item for item in request.objects} + for object_id, scene in request.object_scenes: + item = object_by_id.get(object_id) + if item is None: + skipped.append({"id": object_id, "reason": "missing_object_record"}) + continue + ms = item.get("metric_scale") + if not isinstance(ms, dict): + skipped.append({"id": object_id, "reason": "missing_metric_scale"}) + continue + if ms.get("status") != "ok": + skipped.append({"id": object_id, "reason": str(ms.get("status") or "not_ok")}) + continue + sf = float(ms.get("scale_factor", 1.0)) + if not np.isfinite(sf) or sf <= 0.0: + skipped.append({"id": object_id, "reason": "invalid_simready_scale_factor"}) + continue + try: + srs = np.asarray([float(v) for v in ms.get("normalized_bbox_size_m", [])], dtype=np.float64) + except (TypeError, ValueError): + skipped.append({"id": object_id, "reason": "invalid_normalized_bbox_size_m"}) + continue + if srs.shape != (3,) or np.any(srs <= 0.0): + skipped.append({"id": object_id, "reason": "invalid_normalized_bbox_size_m"}) + continue + cb = np.asarray(GeometryManager.scene_to_mesh(scene).bounds) + cs = cb[1] - cb[0] + if cs.shape != (3,) or np.any(cs <= 0.0): + skipped.append({"id": object_id, "reason": "invalid_current_scene_aabb"}) + continue + geo_ratio = np.sort(cs) / np.sort(srs) + geo_scale = float(np.median(geo_ratio)) + if not np.isfinite(geo_scale) or geo_scale <= 0.0: + skipped.append({"id": object_id, "reason": "non_positive_geo_scale"}) + continue + effective = sf / geo_scale + if not np.isfinite(effective) or effective <= 0.0: + skipped.append({"id": object_id, "reason": "non_positive_effective_scale"}) + continue + used.append({ + "id": object_id, + "effective_scale": effective, + "scale_factor_simready": sf, + "geo_scale": geo_scale, + "simready_bbox_size_m": srs.tolist(), + "simready_bbox_size_cm": (srs * 100.0).tolist(), + "current_scene_bbox_size_m": cs.tolist(), + "current_scene_bbox_size_cm": (cs * 100.0).tolist(), + "target_bbox_dims_cm": ms.get("bbox_dims_cm"), + "confidence": ms.get("confidence"), + }) + + if not used: + return { + "status": "fallback", + "method": "simready_reference_geo_ratio_mean_with_clamp", + "scale_factor": 1.0, + "raw_scale_factor": 1.0, + "was_clamped": False, + "clamp": {"min": request.min_scale, "max": request.max_scale}, + "object_count": len(request.objects), + "used_count": 0, + "skipped_count": len(skipped), + "used": [], + "skipped": skipped, + "unit_note": "No valid metric scale available.", + } + + raw = float(np.mean([u["effective_scale"] for u in used])) + sf = float(np.clip(raw, request.min_scale, request.max_scale)) + return { + "status": "ok", + "method": "simready_reference_geo_ratio_mean_with_clamp", + "scale_factor": sf, + "raw_scale_factor": raw, + "was_clamped": bool(sf != raw), + "clamp": {"min": request.min_scale, "max": request.max_scale}, + "object_count": len(request.objects), + "used_count": len(used), + "skipped_count": len(skipped), + "used": used, + "skipped": skipped, + "unit_note": ( + f"Global scale via per-object metric scale / geo ratio, " + f"clamped to [{request.min_scale:.2f}, {request.max_scale:.2f}]." + ), + } -def _request_axis(value: list[float] | None, default: tuple[float, float, float]) -> list[float]: - if value is not None: - return list(value) - return list(default) - - -def _center_aabb_bottom_xy_at_origin(mesh: Any) -> Any: - bounds = mesh.bounds - bottom_center_x = (float(bounds[0][0]) + float(bounds[1][0])) * 0.5 - bottom_center_y = (float(bounds[0][1]) + float(bounds[1][1])) * 0.5 - centered = mesh.copy() - centered.apply_translation([-bottom_center_x, -bottom_center_y, 0.0]) - return centered - - -def _axis_conversion_transform(source_axis: list[float], target_axis: list[float]) -> np.ndarray: - source = _normalize(np.asarray(source_axis, dtype=np.float64)) - target = _normalize(np.asarray(target_axis, dtype=np.float64)) - return _rotation_between_vectors(source, target) - - -def _place_above_plane_transform(mesh: Any, clearance: float) -> np.ndarray: - min_z = float(mesh.bounds[0][2]) - return _translation_transform(np.array([0.0, 0.0, clearance - min_z])) - - -def _center_aabb_bottom_xy_at_origin_transform(mesh: Any) -> np.ndarray: - bounds = mesh.bounds - bottom_center_x = (float(bounds[0][0]) + float(bounds[1][0])) * 0.5 - bottom_center_y = (float(bounds[0][1]) + float(bounds[1][1])) * 0.5 - return _translation_transform(np.array([-bottom_center_x, -bottom_center_y, 0.0])) - - -def _translation_transform(translation: np.ndarray) -> np.ndarray: - transform = np.eye(4, dtype=np.float64) - transform[:3, 3] = translation - return transform - - -def _scale_transform(scale: float) -> np.ndarray: - transform = np.eye(4, dtype=np.float64) - transform[:3, :3] *= float(scale) - return transform - - -def _as_transform(value: Any) -> np.ndarray: - transform = np.asarray(value, dtype=np.float64) - if transform.shape != (4, 4): - raise ValueError("Expected a 4x4 transform matrix.") - return transform - - -def _rotation_between_vectors(source: np.ndarray, target: np.ndarray) -> np.ndarray: - source = _normalize(source) - target = _normalize(target) - dot = float(np.clip(np.dot(source, target), -1.0, 1.0)) - transform = np.eye(4, dtype=np.float64) - if dot > 1.0 - 1e-8: - return transform - if dot < -1.0 + 1e-8: - axis = _orthogonal_axis(source) - rotation = _axis_angle_rotation(axis, np.pi) - else: - axis = _normalize(np.cross(source, target)) - angle = float(np.arccos(dot)) - rotation = _axis_angle_rotation(axis, angle) - transform[:3, :3] = rotation - return transform - - -def _axis_angle_rotation(axis: np.ndarray, angle: float) -> np.ndarray: - axis = _normalize(axis) - x, y, z = axis - c = float(np.cos(angle)) - s = float(np.sin(angle)) - one_c = 1.0 - c - return np.array( - [ - [c + x * x * one_c, x * y * one_c - z * s, x * z * one_c + y * s], - [y * x * one_c + z * s, c + y * y * one_c, y * z * one_c - x * s], - [z * x * one_c - y * s, z * y * one_c + x * s, c + z * z * one_c], - ], - dtype=np.float64, - ) - - -def _orthogonal_axis(vector: np.ndarray) -> np.ndarray: - axis = np.array([1.0, 0.0, 0.0], dtype=np.float64) - if abs(float(np.dot(vector, axis))) > 0.9: - axis = np.array([0.0, 1.0, 0.0], dtype=np.float64) - return _normalize(np.cross(vector, axis)) - - -def _normalize(vector: np.ndarray) -> np.ndarray: - norm = float(np.linalg.norm(vector)) - if norm == 0.0: - return vector - return vector / norm diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/schemas.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/schemas.py index 86ae22b0a..bc105c9df 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/schemas.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/schemas.py @@ -18,12 +18,22 @@ from dataclasses import dataclass from pathlib import Path +from typing import Any + +__all__ = [ + "EstimateMetricScalesRequest", + "EstimateMetricScalesResult", + "GlobalMetricScaleRequest", + "MakeAssetSimreadyRequest", + "MakeAssetSimreadyResult", + "MakeTableSimreadyRequest", + "MakeTableSimreadyResult", + "MetricScaleObjectInput", +] @dataclass(frozen=True) class MakeAssetSimreadyRequest: - """Request to prepare a general asset GLB for simulation placement.""" - input_path: Path output_path: Path input_up_axis: list[float] | None = None @@ -33,16 +43,12 @@ class MakeAssetSimreadyRequest: @dataclass(frozen=True) class MakeAssetSimreadyResult: - """Result of making an asset simulation-ready.""" - output_path: Path transform_matrix: list[list[float]] @dataclass(frozen=True) class MakeTableSimreadyRequest: - """Request to prepare a generated table GLB for simulation placement.""" - input_path: Path output_path: Path input_up_axis: list[float] | None = None @@ -52,7 +58,42 @@ class MakeTableSimreadyRequest: @dataclass(frozen=True) class MakeTableSimreadyResult: - """Result of making a table simulation-ready.""" - output_path: Path transform_matrix: list[list[float]] + + +@dataclass(frozen=True) +class MetricScaleObjectInput: + object_id: str + object_name: str + object_description: str + mesh_path: Path + + +@dataclass(frozen=True) +class EstimateMetricScalesRequest: + objects: list[MetricScaleObjectInput] + messages: list[dict[str, Any]] + schema: dict[str, Any] + llm: Any + context: str + method: str + step_name: str = "metric_scale" + raw_output_path: Path | None = None + + +@dataclass(frozen=True) +class EstimateMetricScalesResult: + status: str + object_scales: list[dict[str, Any]] + object_payload: list[dict[str, Any]] + raw_model_output: dict[str, Any] | None = None + reason: str = "" + + +@dataclass(frozen=True) +class GlobalMetricScaleRequest: + objects: list[dict[str, Any]] + object_scenes: list[tuple[str, Any]] + min_scale: float = 0.10 + max_scale: float = 10.00 diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/utils.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/utils.py new file mode 100644 index 000000000..1a52dd13d --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/utils.py @@ -0,0 +1,136 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Any + +import numpy as np + +__all__ = [ + "_as_transform", + "_axis_angle_rotation", + "_axis_conversion_transform", + "_center_aabb_bottom_xy_at_origin", + "_center_aabb_bottom_xy_at_origin_transform", + "_normalize", + "_orthogonal_axis", + "_place_above_plane_transform", + "_request_axis", + "_rotation_between_vectors", + "_scale_transform", + "_translation_transform", +] + + +def _request_axis(value: list[float] | None, default: tuple[float, float, float]) -> list[float]: + if value is not None: + return list(value) + return list(default) + + +def _center_aabb_bottom_xy_at_origin(mesh: Any) -> Any: + bounds = mesh.bounds + bottom_center_x = (float(bounds[0][0]) + float(bounds[1][0])) * 0.5 + bottom_center_y = (float(bounds[0][1]) + float(bounds[1][1])) * 0.5 + centered = mesh.copy() + centered.apply_translation([-bottom_center_x, -bottom_center_y, 0.0]) + return centered + + +def _axis_conversion_transform(source_axis: list[float], target_axis: list[float]) -> np.ndarray: + source = _normalize(np.asarray(source_axis, dtype=np.float64)) + target = _normalize(np.asarray(target_axis, dtype=np.float64)) + return _rotation_between_vectors(source, target) + + +def _place_above_plane_transform(mesh: Any, clearance: float) -> np.ndarray: + min_z = float(mesh.bounds[0][2]) + return _translation_transform(np.array([0.0, 0.0, clearance - min_z])) + + +def _center_aabb_bottom_xy_at_origin_transform(mesh: Any) -> np.ndarray: + bounds = mesh.bounds + bottom_center_x = (float(bounds[0][0]) + float(bounds[1][0])) * 0.5 + bottom_center_y = (float(bounds[0][1]) + float(bounds[1][1])) * 0.5 + return _translation_transform(np.array([-bottom_center_x, -bottom_center_y, 0.0])) + + +def _translation_transform(translation: np.ndarray) -> np.ndarray: + transform = np.eye(4, dtype=np.float64) + transform[:3, 3] = translation + return transform + + +def _scale_transform(scale: float) -> np.ndarray: + transform = np.eye(4, dtype=np.float64) + transform[0, 0] = scale + transform[1, 1] = scale + transform[2, 2] = scale + return transform + + +def _as_transform(value: Any) -> np.ndarray: + if isinstance(value, np.ndarray) and value.shape == (4, 4): + return value.astype(np.float64) + raise TypeError(f"Cannot convert {type(value)} to 4x4 transform.") + + +def _rotation_between_vectors(source: np.ndarray, target: np.ndarray) -> np.ndarray: + source = _normalize(source) + target = _normalize(target) + cos_angle = np.dot(source, target) + if cos_angle > 1.0 - 1e-10: + return np.eye(4, dtype=np.float64) + if cos_angle < -1.0 + 1e-10: + axis = _orthogonal_axis(source) + return _axis_angle_rotation(axis, np.pi) + axis = np.cross(source, target) + sin_angle = np.linalg.norm(axis) + axis = axis / sin_angle + angle = np.arctan2(sin_angle, cos_angle) + return _axis_angle_rotation(axis, angle) + + +def _axis_angle_rotation(axis: np.ndarray, angle: float) -> np.ndarray: + axis = _normalize(axis) + c = np.cos(angle) + s = np.sin(angle) + t = 1.0 - c + x, y, z = axis + return np.array( + [ + [t * x * x + c, t * x * y - s * z, t * x * z + s * y, 0.0], + [t * x * y + s * z, t * y * y + c, t * y * z - s * x, 0.0], + [t * x * z - s * y, t * y * z + s * x, t * z * z + c, 0.0], + [0.0, 0.0, 0.0, 1.0], + ], + dtype=np.float64, + ) + + +def _orthogonal_axis(vector: np.ndarray) -> np.ndarray: + x, y, z = _normalize(vector) + if abs(x) < 0.9: + return np.array([1.0, 0.0, -x / (z + 1e-10)], dtype=np.float64) + return np.array([-y / (x + 1e-10), 1.0, 0.0], dtype=np.float64) + + +def _normalize(vector: np.ndarray) -> np.ndarray: + norm = np.linalg.norm(vector) + if norm < 1e-12: + raise ValueError("Cannot normalise zero-length vector.") + return vector / norm diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/__init__.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/__init__.py deleted file mode 100644 index ce2215329..000000000 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - -from __future__ import annotations - -from embodichain.gen_sim.prompt2scene.agent_tools.managers.text_layout_manager.layout import ( - _layout_text_objects_grid, -) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.text_layout_manager.optimization import ( - _optimize_text_layout_slp, -) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.text_layout_manager.settle import ( - settle_text_objects_to_ground, -) - -__all__ = [ - "_layout_text_objects_grid", - "_optimize_text_layout_slp", - "settle_text_objects_to_ground", -] diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/layout.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/layout.py deleted file mode 100644 index 7b94a852e..000000000 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/layout.py +++ /dev/null @@ -1,383 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - - -from __future__ import annotations - -from typing import Any - -import numpy as np - -from embodichain.gen_sim.prompt2scene.agent_tools.managers.optimization_manager import ( - _center_xy_aabb_layout, -) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.text_layout_manager.optimization import ( - _optimize_text_layout_slp, -) -__all__ = [ - "_layout_text_objects_grid", -] - -def _transitive_closure( - nodes: list[str], - edges: list[tuple[str, str]], -) -> list[tuple[str, str]]: - """Floyd–Warshall transitive closure over a small set of nodes.""" - if not nodes or not edges: - return list(edges) - idx = {n: i for i, n in enumerate(nodes)} - n = len(nodes) - adj = [[False] * n for _ in range(n)] - for src, dst in edges: - if src in idx and dst in idx: - adj[idx[src]][idx[dst]] = True - for k in range(n): - for i in range(n): - if adj[i][k]: - row_k = adj[k] - row_i = adj[i] - for j in range(n): - if row_k[j]: - row_i[j] = True - closed: list[tuple[str, str]] = [] - for i in range(n): - for j in range(n): - if adj[i][j]: - closed.append((nodes[i], nodes[j])) - return closed - - - -def _longest_path_ranks( - nodes: list[str], - edges: list[tuple[str, str]], -) -> dict[str, int]: - """Assign integer ranks satisfying ``(A,B)`` → rank[A] < rank[B]. - - Uses topological sort + longest-path DP. Returns a rank dict for every - node in *nodes* (default 0 for isolated nodes). - """ - ranks: dict[str, int] = {n: 0 for n in nodes} - if not edges: - return ranks - # Build adjacency and in-degree - adj: dict[str, list[str]] = {n: [] for n in nodes} - in_deg: dict[str, int] = {n: 0 for n in nodes} - present = set(nodes) - for src, dst in edges: - if src not in present or dst not in present: - continue - adj[src].append(dst) - in_deg[dst] += 1 - # Kahn topological sort - queue = [n for n in nodes if in_deg[n] == 0] - order: list[str] = [] - while queue: - u = queue.pop(0) - order.append(u) - for v in adj[u]: - in_deg[v] -= 1 - if in_deg[v] == 0: - queue.append(v) - # Longest path - for u in order: - for v in adj[u]: - if ranks[v] < ranks[u] + 1: - ranks[v] = ranks[u] + 1 - # Remaining nodes (cycles / isolated) keep rank 0 - return ranks - - - -def _layout_text_objects_grid( - *, - object_ids: list[str], - xy_sizes: dict[str, np.ndarray], - spatial_relations: list[dict[str, Any]], - table_constraints: list[dict[str, Any]] | None = None, - grid_spacing: float = 0.02, - padding_ratio: float = 0.08, -) -> dict[str, Any]: - """Lay out text-scene objects — transitive closure + longest-path ranks. - - 1. Transitive closure of left_of / front_of. - 2. Pick centre: explicit 9‑grid ʻcenterʼ, else highest-degree node. - 3. Longest-path rank assignment (left_of→X, front_of→Y). - 4. Shift 9‑grid anchors to their grid positions. - 5. Free objects auto‑wrap below. - 6. Convert ranks→XY using per‑column/row max sizes + gaps. - 7. SA point optimisation + mesh AABB collision cleanup. - """ - if not object_ids: - return { - "centers": {}, - "initial_centers": {}, - "metadata": { - "method": "transitive_closure_longest_path_with_9grid", - "iterations": 0, - }, - } - - # Parse spatial relations. - left_of_edges: list[tuple[str, str]] = [] - front_of_edges: list[tuple[str, str]] = [] - seen: set[tuple[str, str, str]] = set() - for rel in spatial_relations: - subject = str(rel.get("subject") or "") - obj = str(rel.get("object") or "") - relation = str(rel.get("relation") or "") - if not subject or not obj or subject == obj: - continue - key = (subject, relation, obj) - if key in seen: - continue - seen.add(key) - if relation == "left_of": - left_of_edges.append((subject, obj)) - elif relation == "front_of": - front_of_edges.append((subject, obj)) - - # Compute transitive closures. - left_of_closed = _transitive_closure(object_ids, left_of_edges) - front_of_closed = _transitive_closure(object_ids, front_of_edges) - - # Parse nine-grid constraints. - # −Y = front, so front row = 0, back row = 2 - _GRID_TO_RC: dict[str, tuple[int, int]] = { - "left_front": (0, 0), "center_front": (1, 0), "right_front": (2, 0), - "left_center": (0, 1), "center": (1, 1), "right_center": (2, 1), - "left_back": (0, 2), "center_back": (1, 2), "right_back": (2, 2), - "front": (1, 0), "back": (1, 2), - "left": (0, 1), "right": (2, 1), - } - grid_targets: dict[str, tuple[int, int]] = {} - for tc in (table_constraints or []): - asset = str(tc.get("asset") or "") - grid_name = str(tc.get("grid") or "").strip() - if asset in object_ids and grid_name in _GRID_TO_RC: - grid_targets[asset] = _GRID_TO_RC[grid_name] - - # Select a center object when none is explicit. - auto_center_oid: str | None = None - has_explicit_center = any( - tc.get("grid") == "center" for tc in (table_constraints or []) - ) - if not has_explicit_center: - # Degree = appearances in left_of + front_of (subject or object) - degree: dict[str, int] = {oid: 0 for oid in object_ids} - for src, dst in left_of_closed + front_of_closed: - if src in degree: - degree[src] += 1 - if dst in degree: - degree[dst] += 1 - max_deg = max(degree.values()) if degree else 0 - if max_deg > 0: - candidates = [oid for oid, d in degree.items() if d == max_deg] - # Tie-breaker: largest AABB area - centre_oid = max( - candidates, - key=lambda oid: float(xy_sizes[oid][0]) * float(xy_sizes[oid][1]), - ) - grid_targets[centre_oid] = (1, 1) # 9‑grid centre - auto_center_oid = centre_oid - - # Derive ranks from the transitive closures. - x_rank = _longest_path_ranks(object_ids, left_of_closed) - # −Y = front: A front_of B → A.y < B.y → row[A] < row[B]. - # _longest_path_ranks gives rank[src] < rank[dst]; edges are - # already (A,B) for "A front_of B", so NO reversal needed. - y_rank = _longest_path_ranks(object_ids, front_of_closed) - - # Apply nine-grid shifts. - # Pin 9‑grid objects to their target ranks; shift all connected - # objects (both upstream and downstream) to preserve topology. - if grid_targets: - # Build undirected connected-components via relation edges - all_edges = left_of_closed + front_of_closed - neighbours: dict[str, set[str]] = {oid: set() for oid in object_ids} - for src, dst in all_edges: - if src in neighbours and dst in neighbours: - neighbours[src].add(dst) - neighbours[dst].add(src) - for oid in grid_targets: - neighbours.setdefault(oid, set()) - - # For each 9‑grid object, BFS the component and shift uniformly - shifted: set[str] = set() - for oid, (target_col, target_row) in grid_targets.items(): - if oid in shifted: - continue - dx = target_col - x_rank.get(oid, 0) - dy = target_row - y_rank.get(oid, 0) - - # BFS to collect the full connected component - component: set[str] = {oid} - queue = [oid] - while queue: - u = queue.pop(0) - for v in neighbours.get(u, set()): - if v not in component: - component.add(v) - queue.append(v) - - for oid2 in component: - if oid2 not in grid_targets: # only shift non‑anchored objects - x_rank[oid2] = x_rank.get(oid2, 0) + dx - y_rank[oid2] = y_rank.get(oid2, 0) + dy - shifted.update(component) - - # Propagate row and column alignment. - # left_of A B → same row (y_rank[A] = y_rank[B]) - # front_of A B → same col (x_rank[A] = x_rank[B]) - # Priority (higher wins): 9‑grid > higher degree > larger area. - _prio = { - oid: ( - oid in grid_targets, - sum(1 for e in left_of_closed + front_of_closed if oid in e), - float(xy_sizes[oid][0]) * float(xy_sizes[oid][1]), - ) - for oid in object_ids - } - for src, dst in left_of_closed: - if _prio[src] >= _prio[dst]: - y_rank[dst] = y_rank.get(src, 0) - else: - y_rank[src] = y_rank.get(dst, 0) - for src, dst in front_of_closed: - if _prio[src] >= _prio[dst]: - x_rank[dst] = x_rank.get(src, 0) - else: - x_rank[src] = x_rank.get(dst, 0) - - # Normalise to >= 0 - min_x = min(x_rank.values()) if x_rank else 0 - min_y = min(y_rank.values()) if y_rank else 0 - for oid in object_ids: - x_rank[oid] = x_rank.get(oid, 0) - min_x - y_rank[oid] = y_rank.get(oid, 0) - min_y - - # Resolve cell collisions: spread objects sharing the same (col, row) - cell_occupants: dict[tuple[int, int], list[str]] = {} - for oid in object_ids: - cell = (x_rank[oid], y_rank[oid]) - cell_occupants.setdefault(cell, []).append(oid) - for (col, row), occupants in cell_occupants.items(): - if len(occupants) > 1: - for offset, oid in enumerate(occupants[1:], start=1): - x_rank[oid] = col + offset - - # Place unconstrained objects in wrapped rows. - constrained = set() - for src, dst in left_of_closed + front_of_closed: - constrained.update([src, dst]) - constrained.update(grid_targets) - free_objects = [oid for oid in object_ids if oid not in constrained] - - if free_objects: - free_row = max(y_rank.values()) + 1 if y_rank else 0 - # Max row width ≈ existing union width × 1.5 (at least 3 cols) - col_keys = list(x_rank.values()) - existing_cols = max(col_keys) - min(col_keys) + 1 if col_keys else 1 - max_cols_per_row = max(existing_cols, 3) - free_sorted = sorted( - free_objects, - key=lambda oid: float(xy_sizes[oid][0]), - reverse=True, - ) - col = 0 - row_offset = 0 - for oid in free_sorted: - x_rank[oid] = col - y_rank[oid] = free_row + row_offset - col += 1 - if col >= max_cols_per_row: - col = 0 - row_offset += 1 - - # Convert ranks to XY positions. - col_widths: dict[int, float] = {} - row_heights: dict[int, float] = {} - for oid in object_ids: - c = x_rank[oid] - r = y_rank[oid] - col_widths[c] = max(col_widths.get(c, 0.0), float(xy_sizes[oid][0])) - row_heights[r] = max(row_heights.get(r, 0.0), float(xy_sizes[oid][1])) - - x_cumsum: dict[int, float] = {} - cumulative = 0.0 - for c in sorted(col_widths): - x_cumsum[c] = cumulative - cumulative += col_widths[c] + grid_spacing - - y_cumsum: dict[int, float] = {} - cumulative = 0.0 - for r in sorted(row_heights): - y_cumsum[r] = cumulative - cumulative += row_heights[r] + grid_spacing - - centers: dict[str, np.ndarray] = {} - for oid in object_ids: - c = x_rank[oid] - r = y_rank[oid] - cx = x_cumsum[c] + 0.5 * float(xy_sizes[oid][0]) - cy = y_cumsum[r] + 0.5 * float(xy_sizes[oid][1]) - centers[oid] = np.array([cx, cy], dtype=np.float64) - - centers = _center_xy_aabb_layout(centers=centers, xy_sizes=xy_sizes) - - initial_centers = {oid: c.copy() for oid, c in centers.items()} - - # Snap initial grid positions as 9‑grid spring targets - grid_spring_targets: dict[str, np.ndarray] = { - oid: initial_centers[oid].copy() - for oid in grid_targets - if oid in initial_centers - } - - # Optimize positions and remove mesh AABB collisions. - optimized = _optimize_text_layout_slp( - object_ids=object_ids, - xy_sizes=xy_sizes, - initial_centers=initial_centers, - left_of_edges=left_of_closed, - front_of_edges=front_of_closed, - grid_spring_targets=grid_spring_targets, - padding_ratio=padding_ratio, - ) - centers = optimized["centers"] - optimization_metadata = optimized["metadata"] - - # Collect layout metadata. - metadata = { - "method": "transitive_closure_longest_path_with_9grid_and_sa", - "grid_spacing": grid_spacing, - "auto_center_oid": auto_center_oid, - "has_explicit_center": has_explicit_center, - "table_constraint_count": len(grid_targets), - "left_of_count": len(left_of_edges), - "left_of_closed_count": len(left_of_closed), - "front_of_count": len(front_of_edges), - "front_of_closed_count": len(front_of_closed), - "free_object_count": len(free_objects), - "x_ranks": {oid: x_rank.get(oid, 0) for oid in object_ids}, - "y_ranks": {oid: y_rank.get(oid, 0) for oid in object_ids}, - "optimization": optimization_metadata, - } - return { - "centers": centers, - "initial_centers": initial_centers, - "metadata": metadata, - } diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/optimization.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/optimization.py deleted file mode 100644 index b8915fc4c..000000000 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/optimization.py +++ /dev/null @@ -1,404 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - -from __future__ import annotations - -from typing import Any - -import numpy as np -from scipy.optimize import minimize - -from embodichain.gen_sim.prompt2scene.agent_tools.managers.optimization_manager import ( - _center_xy_aabb_layout, - _footprint_layout_diagnostics, - _xy_aabb_overlap, - _xy_union_bounds, -) - -__all__ = ["_optimize_text_layout_slp"] - -# SLSQP solve options — matching the original example_optimization SA pipeline. -_SLSQP_OPTIONS: dict[str, Any] = {"maxiter": 500, "ftol": 1e-6, "disp": False} - -# Objective weights (relations are hard constraints, not in the objective). -_WEIGHTS: dict[str, float] = { - "seed": 1.0, - "overlap": 200.0, - "grid": 100.0, -} - - -def _optimize_text_layout_slp( - *, - object_ids: list[str], - xy_sizes: dict[str, np.ndarray], - initial_centers: dict[str, np.ndarray], - left_of_edges: list[tuple[str, str]], - front_of_edges: list[tuple[str, str]], - grid_spring_targets: dict[str, np.ndarray], - padding_ratio: float, -) -> dict[str, Any]: - """Optimize 2D centres with scipy SLSQP, then remove mesh AABB overlap. - - Mirroring the original example_optimization/SA pipeline: - - left_of / front_of → linear inequality constraints - - bounding box → variable bounds (2× initial union) - - seed / overlap / grid → soft penalties in the objective - - post‑solve collision cleanup on actual footprint AABBs - """ - if not object_ids: - return { - "centers": {}, - "metadata": { - "method": "text_slsqp_then_mesh_aabb_collision_removal", - "slsqp_iterations": 0, - "collision_iterations": 0, - }, - } - - max_extent = max( - float(max(xy_sizes[oid][0], xy_sizes[oid][1])) for oid in object_ids - ) - padding = max(max_extent * padding_ratio, 1e-3) - - initial_centers = { - oid: np.asarray(initial_centers[oid], dtype=np.float64).copy() - for oid in object_ids - } - initial_union_bounds = _xy_union_bounds( - centers=initial_centers, - xy_sizes=xy_sizes, - ) - - index_by_id = {oid: i for i, oid in enumerate(object_ids)} - x0 = _pack_centers(object_ids, initial_centers) - - # Build linear inequality constraints for left_of and front_of. - constraints: list[dict[str, Any]] = [] - _build_relation_constraints( - constraints=constraints, - object_ids=object_ids, - index_by_id=index_by_id, - xy_sizes=xy_sizes, - left_of_edges=left_of_edges, - front_of_edges=front_of_edges, - padding=padding, - ) - - # Bound variables to twice the initial union size. - init_size = initial_union_bounds[1] - initial_union_bounds[0] - margin = init_size * 0.5 # 50 % each side → 2× total - bounds = [] - for oid in object_ids: - bounds.append( - ( - float(initial_union_bounds[0, 0] - margin[0]), - float(initial_union_bounds[1, 0] + margin[0]), - ) - ) # x - bounds.append( - ( - float(initial_union_bounds[0, 1] - margin[1]), - float(initial_union_bounds[1, 1] + margin[1]), - ) - ) # y - - # Define the optimization objective. - def _objective(xvec: np.ndarray) -> float: - centers = _unpack_centers(object_ids, xvec) - loss = 0.0 - - # seed: stay close to initial positions - for oid in object_ids: - delta = centers[oid] - initial_centers[oid] - loss += _WEIGHTS["seed"] * float(np.dot(delta, delta)) - - # overlap: AABB overlap area penalty - for i, oid in enumerate(object_ids): - for other_id in object_ids[i + 1 :]: - ov = _xy_aabb_overlap( - center_a=centers[oid], - size_a=xy_sizes[oid], - center_b=centers[other_id], - size_b=xy_sizes[other_id], - padding=padding, - ) - if ov is not None: - loss += _WEIGHTS["overlap"] * float(ov[0] * ov[1]) - - # grid: spring toward 9‑grid targets - for oid, target in grid_spring_targets.items(): - if oid not in centers: - continue - delta = centers[oid] - target - loss += _WEIGHTS["grid"] * float(np.dot(delta, delta)) - - return float(loss) - - # Solve the constrained optimization problem. - slsqp_result: dict[str, Any] = {"success": False, "nit": 0, "message": ""} - try: - result = minimize( - _objective, - x0, - method="SLSQP", - bounds=bounds, - constraints=constraints, - options=_SLSQP_OPTIONS, - ) - slsqp_result = { - "success": bool(result.success), - "nit": int(getattr(result, "nit", 0)), - "message": str(result.message), - "fun": float(result.fun) if result.fun is not None else None, - } - if result.success: - x_opt = result.x - else: - # SLSQP failed — fall back to seed positions - x_opt = x0.copy() - except Exception: - x_opt = x0.copy() - slsqp_result["message"] = "SLSQP raised an exception; using seed positions." - - centers = _unpack_centers(object_ids, x_opt) - centers = _center_xy_aabb_layout(centers=centers, xy_sizes=xy_sizes) - - # Remove residual collisions. - centers, collision_metadata = _remove_mesh_aabb_collisions( - object_ids=object_ids, - xy_sizes=xy_sizes, - centers=centers, - initial_centers=initial_centers, - left_of_edges=left_of_edges, - front_of_edges=front_of_edges, - padding=padding, - ) - centers = _center_xy_aabb_layout(centers=centers, xy_sizes=xy_sizes) - - # Collect optimization metadata. - diagnostics = _footprint_layout_diagnostics( - object_ids=object_ids, - centers=centers, - initial_centers=initial_centers, - xy_sizes=xy_sizes, - padding=padding, - initial_union_bounds=initial_union_bounds, - ) - metadata: dict[str, Any] = { - "method": "text_slsqp_then_mesh_aabb_collision_removal", - "relation_usage": "left_of_front_of_hard_constraints", - "padding": float(padding), - "padding_ratio": float(padding_ratio), - "weights": dict(_WEIGHTS), - "slsqp": slsqp_result, - "bounds_expansion": 2.0, - "initial_union_size": init_size.tolist(), - **collision_metadata, - "final_centers": { - oid: centers[oid].tolist() for oid in object_ids - }, - **diagnostics, - } - return {"centers": centers, "metadata": metadata} - - -# Build relation constraints. - - -def _build_relation_constraints( - *, - constraints: list[dict[str, Any]], - object_ids: list[str], - index_by_id: dict[str, int], - xy_sizes: dict[str, np.ndarray], - left_of_edges: list[tuple[str, str]], - front_of_edges: list[tuple[str, str]], - padding: float, -) -> None: - """Append SLSQP inequality constraints for left_of / front_of edges.""" - - for subject, obj in left_of_edges: - if subject not in index_by_id or obj not in index_by_id: - continue - i_a = index_by_id[subject] - i_b = index_by_id[obj] - # A.x + gap ≤ B.x → B.x - A.x - gap ≥ 0 - gap = ( - 0.5 * float(xy_sizes[subject][0]) - + 0.5 * float(xy_sizes[obj][0]) - + padding - ) - constraints.append( - { - "type": "ineq", - "fun": lambda x, ia=i_a, ib=i_b, g=gap: float( - x[2 * ib] - x[2 * ia] - g - ), - } - ) - - for subject, obj in front_of_edges: - if subject not in index_by_id or obj not in index_by_id: - continue - i_a = index_by_id[subject] - i_b = index_by_id[obj] - # A.y + gap ≤ B.y → B.y - A.y - gap ≥ 0 - gap = ( - 0.5 * float(xy_sizes[subject][1]) - + 0.5 * float(xy_sizes[obj][1]) - + padding - ) - constraints.append( - { - "type": "ineq", - "fun": lambda x, ia=i_a, ib=i_b, g=gap: float( - x[2 * ib + 1] - x[2 * ia + 1] - g - ), - } - ) - - -# Remove AABB collisions. - - -def _remove_mesh_aabb_collisions( - *, - object_ids: list[str], - xy_sizes: dict[str, np.ndarray], - centers: dict[str, np.ndarray], - initial_centers: dict[str, np.ndarray], - left_of_edges: list[tuple[str, str]], - front_of_edges: list[tuple[str, str]], - padding: float, -) -> tuple[dict[str, np.ndarray], dict[str, Any]]: - relation_pairs = set(left_of_edges + front_of_edges) - relation_pairs.update((b, a) for a, b in left_of_edges + front_of_edges) - current = { - oid: np.asarray(center, dtype=np.float64).copy() - for oid, center in centers.items() - } - max_rounds = 80 - total_pushes = 0 - last_overlap_count = 0 - - for iteration in range(max_rounds): - overlaps = _mesh_aabb_collision_pairs( - object_ids=object_ids, - xy_sizes=xy_sizes, - centers=current, - padding=padding, - ) - last_overlap_count = len(overlaps) - if not overlaps: - return current, { - "collision_iterations": iteration, - "collision_pushes": total_pushes, - "collision_remaining": 0, - "collision_removal": "iterative_mesh_aabb_push", - } - for item in overlaps: - object_a = item["object"] - object_b = item["other"] - axis = int(item["axis"]) - sign = -1.0 if current[object_a][axis] <= current[object_b][axis] else 1.0 - amount = 0.5 * (float(item["overlap"]) + 1.0e-6) - if (object_a, object_b) in relation_pairs: - current[object_a][axis] += sign * amount - current[object_b][axis] -= sign * amount - else: - drift_a = np.linalg.norm( - current[object_a] - initial_centers[object_a] - ) - drift_b = np.linalg.norm( - current[object_b] - initial_centers[object_b] - ) - if drift_a <= drift_b: - current[object_a][axis] += sign * amount * 1.25 - current[object_b][axis] -= sign * amount * 0.75 - else: - current[object_a][axis] += sign * amount * 0.75 - current[object_b][axis] -= sign * amount * 1.25 - total_pushes += 1 - current = _center_xy_aabb_layout(centers=current, xy_sizes=xy_sizes) - - return current, { - "collision_iterations": max_rounds, - "collision_pushes": total_pushes, - "collision_remaining": last_overlap_count, - "collision_removal": "iterative_mesh_aabb_push", - } - - -def _mesh_aabb_collision_pairs( - *, - object_ids: list[str], - xy_sizes: dict[str, np.ndarray], - centers: dict[str, np.ndarray], - padding: float, -) -> list[dict[str, Any]]: - pairs: list[dict[str, Any]] = [] - for i, oid in enumerate(object_ids): - for other_id in object_ids[i + 1 :]: - ov = _xy_aabb_overlap( - center_a=centers[oid], - size_a=xy_sizes[oid], - center_b=centers[other_id], - size_b=xy_sizes[other_id], - padding=padding, - ) - if ov is None: - continue - axis = 0 if ov[0] <= ov[1] else 1 - pairs.append( - { - "object": oid, - "other": other_id, - "axis": axis, - "overlap": float(ov[axis]), - "overlap_x": float(ov[0]), - "overlap_y": float(ov[1]), - } - ) - pairs.sort(key=lambda item: item["overlap"], reverse=True) - return pairs - - -# Pack and unpack center coordinates. - - -def _pack_centers( - object_ids: list[str], - centers: dict[str, np.ndarray], -) -> np.ndarray: - values: list[float] = [] - for oid in object_ids: - c = np.asarray(centers[oid], dtype=np.float64) - values.extend([float(c[0]), float(c[1])]) - return np.asarray(values, dtype=np.float64) - - -def _unpack_centers( - object_ids: list[str], - xvec: np.ndarray, -) -> dict[str, np.ndarray]: - return { - oid: np.asarray( - [xvec[2 * i], xvec[2 * i + 1]], - dtype=np.float64, - ) - for i, oid in enumerate(object_ids) - } diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/alignment.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_layout_alignment.py similarity index 89% rename from embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/alignment.py rename to embodichain/gen_sim/prompt2scene/agent_tools/tools/image_layout_alignment.py index 6d7084f44..ede21b08b 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/alignment.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_layout_alignment.py @@ -23,17 +23,17 @@ import numpy as np -from embodichain.gen_sim.prompt2scene.workflows.llm_output import ( +from embodichain.gen_sim.prompt2scene.llms.llm_output import ( call_structured_json_model_step, ) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager.prompts import ( +from embodichain.gen_sim.prompt2scene.prompts.builders import ( build_up_down_flip_check_messages, ) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.metric_scale_manager import ( +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simready_manager import ( GlobalMetricScaleRequest, MetricScaleManager, ) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager.schemas import ( +from embodichain.gen_sim.prompt2scene.prompts.schemas import ( UP_DOWN_FLIP_CHECK_JSON_SCHEMA, ) @@ -47,26 +47,14 @@ MatplotlibManager, RenderImageComparisonRequest, ) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.scene_geometry import ( - _aabb_center, - _copy_scene_with_transform, - _estimate_support_normal, - _load_scene_with_transform, - _matrix_from_json, - _rotation_between_vectors, - _scale_transform, - _scene_to_mesh, - _support_normal_flip_transform, - _xy_aabb_center, - _z_up_to_glb_y_up_transform, - _z_yaw_transform, +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager import ( + GeometryManager, ) from embodichain.gen_sim.prompt2scene.utils.io import ( relative_path, ) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.optimization_manager import ( - _object_scenes_xy_aabb_manifest, - _settle_and_pack_object_footprints, +from embodichain.gen_sim.prompt2scene.agent_tools.managers.layout_manager import ( + LayoutManager, ) __all__ = ["_export_support_aligned_layout_glbs"] @@ -106,7 +94,7 @@ def _export_support_aligned_layout_glbs( raise FileNotFoundError( f"Support reference table GLB not found: {support_reference_path}" ) - support_reference_transform = _matrix_from_json( + support_reference_transform = GeometryManager.matrix_from_json( table.get("support_reference_transform_matrix") or table.get("transform_matrix"), name="table.support_reference_transform_matrix", @@ -119,9 +107,9 @@ def _export_support_aligned_layout_glbs( object_scenes = [ ( object_id, - _load_scene_with_transform( + GeometryManager.load_scene_with_transform( path=path, - transform=_matrix_from_json( + transform=GeometryManager.matrix_from_json( transform, name=f"{object_id}.transform_matrix", ), @@ -130,9 +118,9 @@ def _export_support_aligned_layout_glbs( ) for object_id, path, transform in object_paths ] - table_mesh = _scene_to_mesh(support_reference_scene, trimesh=trimesh) - support_normal = _estimate_support_normal(table_mesh) - normal_alignment = _rotation_between_vectors( + table_mesh = GeometryManager.scene_to_mesh(support_reference_scene, trimesh=trimesh) + support_normal = GeometryManager.estimate_support_normal(table_mesh) + normal_alignment = GeometryManager.rotation_between_vectors( support_normal, np.array([0.0, 0.0, 1.0]), ) @@ -141,7 +129,8 @@ def _export_support_aligned_layout_glbs( scene.apply_transform(normal_alignment) object_bounds = [ - _scene_to_mesh(scene, trimesh=trimesh).bounds for _, scene in object_scenes + GeometryManager.scene_to_mesh(scene, trimesh=trimesh).bounds + for _, scene in object_scenes ] clutter_bounds = np.vstack( [ @@ -191,12 +180,14 @@ def _export_support_aligned_layout_glbs( object_scenes=object_scenes, ) ) - metric_scale_transform = _scale_transform(global_metric_scale["scale_factor"]) + metric_scale_transform = GeometryManager.scale_transform( + global_metric_scale["scale_factor"] + ) if float(global_metric_scale["scale_factor"]) != 1.0: for _, scene in object_scenes: scene.apply_transform(metric_scale_transform) - footprint_result = _settle_and_pack_object_footprints( + footprint_result = LayoutManager.settle_and_pack_object_footprints( object_scenes=object_scenes, output_dir=output_dir / "footprint_layout", output_root=output_root, @@ -204,11 +195,14 @@ def _export_support_aligned_layout_glbs( ) object_scenes = footprint_result["object_scenes"] - output_axis_transform = _z_up_to_glb_y_up_transform() + output_axis_transform = GeometryManager.z_up_to_glb_y_up_transform() object_outputs = [] for object_id, scene in object_scenes: object_output = output_dir / f"{object_id}_aligned.glb" - _copy_scene_with_transform(scene, output_axis_transform).export(object_output) + GeometryManager.copy_scene_with_transform( + scene, + output_axis_transform, + ).export(object_output) object_outputs.append( { "id": object_id, @@ -218,7 +212,7 @@ def _export_support_aligned_layout_glbs( alignment_matrix = selected_extra_transform @ center_transform @ normal_alignment scaled_alignment_matrix = metric_scale_transform @ alignment_matrix - final_clutter_aabb_2d_cm = _object_scenes_xy_aabb_manifest( + final_clutter_aabb_2d_cm = LayoutManager.object_scenes_xy_aabb_manifest( object_scenes=object_scenes, trimesh=trimesh, unit_scale=100.0, @@ -270,7 +264,7 @@ def _build_up_down_alignment_candidates( spatial_relations: list[dict[str, Any]], trimesh: Any, ) -> dict[str, dict[str, Any]]: - flip_transform = _support_normal_flip_transform( + flip_transform = GeometryManager.support_normal_flip_transform( support_normal=support_normal, normal_alignment=normal_alignment, ) @@ -281,12 +275,15 @@ def _build_up_down_alignment_candidates( ("flipped", flip_transform), ]: candidate_object_scenes = [ - (object_id, _copy_scene_with_transform(scene, pre_yaw_transform)) + ( + object_id, + GeometryManager.copy_scene_with_transform(scene, pre_yaw_transform), + ) for object_id, scene in object_scenes ] object_bounds = { object_id: np.asarray( - _scene_to_mesh(scene, trimesh=trimesh).bounds, + GeometryManager.scene_to_mesh(scene, trimesh=trimesh).bounds, dtype=np.float64, ) for object_id, scene in candidate_object_scenes @@ -295,7 +292,7 @@ def _build_up_down_alignment_candidates( object_bounds=object_bounds, relations=directional_relations, ) - yaw_transform = _z_yaw_transform( + yaw_transform = GeometryManager.z_yaw_transform( float(yaw_metadata["yaw_degrees"]), ) for _, scene in candidate_object_scenes: @@ -325,14 +322,15 @@ def _best_spatial_yaw( } object_centers = { - object_id: _aabb_center(bounds) for object_id, bounds in object_bounds.items() + object_id: GeometryManager.aabb_center(bounds) + for object_id, bounds in object_bounds.items() } best_yaw = 0 best_score = -1 best_raw_gap_sum = float("-inf") best_relation_scores: list[dict[str, Any]] = [] for yaw_degrees in range(360): - rotation = _z_yaw_transform(float(yaw_degrees)) + rotation = GeometryManager.z_yaw_transform(float(yaw_degrees)) rotated_centers = { object_id: _transform_point(rotation, center) for object_id, center in object_centers.items() diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py index 9d3e42f1d..4bc8cbb5b 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py @@ -41,7 +41,7 @@ AssetImageToRgbaRequest, ImageSegmentationManager, ) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager import ( +from embodichain.gen_sim.prompt2scene.agent_tools.tools.image_layout_alignment import ( _export_support_aligned_layout_glbs, ) from embodichain.gen_sim.prompt2scene.agent_tools.managers.simready_manager import ( @@ -49,24 +49,20 @@ MakeTableSimreadyRequest, SimreadyManager, ) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.metric_scale_manager import ( +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simready_manager import ( METRIC_SCALE_ENABLED, EstimateMetricScalesRequest, + IMAGE_METRIC_SCALE_JSON_SCHEMA, MetricScaleManager, MetricScaleObjectInput, + build_image_metric_scale_messages, ) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.scene_geometry import ( - _compose_sam3d_multi_object_transform, +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager import ( + GeometryManager, ) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager import ( +from embodichain.gen_sim.prompt2scene.agent_tools.tools.layout_manifests import ( _write_multi_object_layout_manifests, ) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager.prompts import ( - build_image_metric_scale_messages, -) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.image_scene_manager.schemas import ( - IMAGE_METRIC_SCALE_JSON_SCHEMA, -) from embodichain.gen_sim.prompt2scene.utils.io import ( relative_path, ) @@ -212,7 +208,7 @@ def generate_image_scene_assets( status_parts: list[str] = [] transform_matrix: list[list[float]] = [] try: - transform = _compose_sam3d_multi_object_transform( + transform = GeometryManager.compose_sam3d_multi_object_transform( rotation_quaternion_wxyz=generated.rotation_quaternion_wxyz, translation=generated.translation, scale=generated.scale, diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_segment_filter.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_segment_filter.py new file mode 100644 index 000000000..28d3a99ab --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_segment_filter.py @@ -0,0 +1,189 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Callable + +from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client import ( + draw_numbered_masks, + sort_segments_by_bbox, +) +from embodichain.gen_sim.prompt2scene.llms.llm_output import ( + call_structured_json_model_step, + is_model_output_error, +) +from embodichain.gen_sim.prompt2scene.prompts.builders import ( + build_filter_extra_instances_messages, +) +from embodichain.gen_sim.prompt2scene.prompts.schemas import ( + FILTER_EXTRA_INSTANCES_JSON_SCHEMA, +) +from embodichain.gen_sim.prompt2scene.utils import log_api_request_start, log +from embodichain.gen_sim.prompt2scene.workflows.stage_errors import ( + format_attempt_error, +) + +__all__ = [ + "filter_group_segments_with_vlm", + "filter_segments_with_vlm", + "remove_extra_numbered_segments", +] + +DebugWriter = Callable[[str, str, dict[str, Any]], Path] + + +def remove_extra_numbered_segments( + *, + segments: list[dict[str, Any]], + raw_model_output: dict[str, Any], +) -> list[dict[str, Any]]: + """Remove numbered masks flagged as extra by the VLM.""" + extra_numbers = raw_model_output.get("extra_instance_numbers") + if not isinstance(extra_numbers, list): + raise ValueError("extra_instance_numbers must be a list.") + extra_indices = {int(number) - 1 for number in extra_numbers} + if any(index < 0 or index >= len(segments) for index in extra_indices): + raise ValueError("VLM returned an out-of-range extra mask number.") + return [ + segment + for index, segment in enumerate(segments) + if index not in extra_indices + ] + + +def filter_group_segments_with_vlm( + *, + llm: Any, + image_path: Path, + step_name: str, + group: dict[str, Any], + segments: list[dict[str, Any]], + stage: str, + debug_round_name: str, + debug_round_dir: Path, + write_debug_json: DebugWriter, +) -> list[dict[str, Any]]: + """Ask VLM to remove wrong or duplicate instances from one SAM3 result. + + All path concerns are injected via *step_name*, *debug_round_name*, + *debug_round_dir*, and *write_debug_json* so the tool does not depend + on workflow internals. + """ + segments = sort_segments_by_bbox(segments) + if not segments: + return segments + + debug_image_path = draw_numbered_masks( + image_path=image_path, + segments=segments, + output_path=debug_round_dir / "mask.png", + ) + debug_images = list(group.get("debug_images") or []) + if str(debug_image_path) not in debug_images: + debug_images.append(str(debug_image_path)) + group["debug_images"] = debug_images + + log_api_request_start( + step=step_name, + request=f"vlm_filter_{stage}", + debug_image=str(debug_image_path), + ) + messages = build_filter_extra_instances_messages( + debug_image_path=debug_image_path, + name=group["name"], + description=group["description"], + expected_count=group["expected_count"], + class_candidate=group["class_candidate"], + ) + raw_model_output = call_structured_json_model_step( + llm=llm, + schema=FILTER_EXTRA_INSTANCES_JSON_SCHEMA, + messages=messages, + context=f"Image relation {stage} segmentation filtering", + + + attempt_count=0, + raw_output_writer=lambda payload: write_debug_json( + round_name=debug_round_name, + filename="raw_model_output.json", + payload=payload, + ), + ) + return remove_extra_numbered_segments( + segments=segments, + raw_model_output=raw_model_output, + ) + + +def filter_segments_with_vlm( + *, + llm: Any, + image_path: Path, + step_name: str, + segment_groups: list[dict[str, Any]], + attempt_count: int, + errors: list[str], + stage: str, + next_debug_round_name: Callable[[str], str], + debug_round_dir: Callable[[str], Path], + write_debug_json: DebugWriter, +) -> dict[str, object]: + """Filter all segment groups with VLM and return an updated state patch. + + All path concerns are injected via callbacks so the tool does not + depend on workflow internals. + """ + result_groups: list[dict[str, Any]] = [] + current_attempt = attempt_count + 1 + + try: + for group in segment_groups: + group = dict(group) + name = str(group.get("name", "unknown")) + round_name = next_debug_round_name(f"{stage}_{name}") + round_dir = debug_round_dir(round_name) + group["segments"] = filter_group_segments_with_vlm( + llm=llm, + image_path=image_path, + step_name=step_name, + group=group, + segments=group["segments"], + stage=stage, + debug_round_name=round_name, + debug_round_dir=round_dir, + write_debug_json=write_debug_json, + ) + result_groups.append(group) + except Exception as exc: + if is_model_output_error(exc) or isinstance(exc, ValueError): + error = format_attempt_error( + "Image relations VLM filter", current_attempt, exc + ) + log.log_warning(error) + return { + "attempt_count": current_attempt, + "last_error": error, + "errors": errors + [error], + } + raise + + return { + "attempt_count": current_attempt, + "segment_groups": result_groups, + "last_error": None, + } diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/manifests.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/layout_manifests.py similarity index 95% rename from embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/manifests.py rename to embodichain/gen_sim/prompt2scene/agent_tools/tools/layout_manifests.py index 6ae379c3e..6fb9e0c54 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/image_scene_manager/manifests.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/layout_manifests.py @@ -23,10 +23,8 @@ from embodichain.gen_sim.prompt2scene.utils.io import ( relative_path, ) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.scene_geometry import ( - _compose_json_matrices, - _compose_simready_to_aligned_matrix, - _decompose_transform_matrix, +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager import ( + GeometryManager, ) from embodichain.gen_sim.prompt2scene.utils.io import write_json @@ -172,16 +170,16 @@ def _simready_to_aligned_manifest_item( item_id, alignment_matrix, ) - raw_to_aligned_matrix = _compose_json_matrices( + raw_to_aligned_matrix = GeometryManager.compose_json_matrices( glb_output_axis_transform, item_alignment_matrix, sam3d_transform, ) - simready_to_aligned_matrix = _compose_simready_to_aligned_matrix( + simready_to_aligned_matrix = GeometryManager.compose_simready_to_aligned_matrix( raw_to_aligned_matrix=raw_to_aligned_matrix, raw_to_simready_matrix=item.get("raw_to_simready_glb_matrix", []), ) - decomposed = _decompose_transform_matrix(simready_to_aligned_matrix) + decomposed = GeometryManager.decompose_transform_matrix(simready_to_aligned_matrix) return { "id": item_id, "name": item.get("name", ""), diff --git a/embodichain/gen_sim/prompt2scene/workflows/spatial.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/spatial_relations.py similarity index 100% rename from embodichain/gen_sim/prompt2scene/workflows/spatial.py rename to embodichain/gen_sim/prompt2scene/agent_tools/tools/spatial_relations.py diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_clutter_fit.py similarity index 80% rename from embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py rename to embodichain/gen_sim/prompt2scene/agent_tools/tools/table_clutter_fit.py index eeb79a182..e3e6f5296 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/table_clutter_fit_manager/manager.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_clutter_fit.py @@ -17,24 +17,14 @@ from __future__ import annotations -import json import tempfile from pathlib import Path from typing import Any import numpy as np -from embodichain.gen_sim.prompt2scene.utils.io import relative_path -from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.scene_geometry import ( - _copy_scene_with_transform, - _scene_to_mesh, - _z_up_to_glb_y_up_transform, - _detect_table_fit_support_quad, - _load_table_fit_scene_internal_z, - _table_fit_bounds_xy_manifest, - _table_fit_safe_positive_ratio, - _table_fit_scene_union_bounds, - _table_fit_uniform_xy_scale_transform, +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager import ( + GeometryManager, ) from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager import ( SimulationManager, @@ -65,7 +55,7 @@ def _gravity_settle_table_fit_internal_z_scene( with tempfile.TemporaryDirectory(prefix="p2s_table_fit_gravity_") as tmp: tmp_path = Path(tmp) pre_gravity = tmp_path / "table_pre_gravity.glb" - _copy_scene_with_transform(scene, z_to_y).export(pre_gravity) + GeometryManager.copy_scene_with_transform(scene, z_to_y).export(pre_gravity) result = sim.run_gravity_simulation( GravityDropRequest( glb_path=pre_gravity, @@ -78,20 +68,14 @@ def _gravity_settle_table_fit_internal_z_scene( return settled -def _write_table_fit_json(path: Path, data: dict[str, Any]) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text( - json.dumps(data, ensure_ascii=False, indent=2) + "\n", - encoding="utf-8", - ) - - def fit_table_to_clutter( *, table_result: dict[str, Any], clutter_result: dict[str, Any], output_root: Path, output_dir: Path, + table_output_path: Path | None = None, + object_output_paths: dict[str, Path] | None = None, margin_cm: float = 10.0, support_occupancy_ratio: float = 0.80, object_coverage_percent: int | None = None, @@ -115,6 +99,14 @@ def fit_table_to_clutter( output_root = output_root.expanduser().resolve() output_dir = output_dir.expanduser().resolve() output_dir.mkdir(parents=True, exist_ok=True) + if table_output_path is None: + table_output_path = output_dir / "table_fit_to_clutter.glb" + table_output_path = table_output_path.expanduser().resolve() + table_output_path.parent.mkdir(parents=True, exist_ok=True) + object_output_paths = { + str(key): path.expanduser().resolve() + for key, path in (object_output_paths or {}).items() + } # Resolve the table geometry. table_simready_path = _resolve_generated_path( @@ -145,30 +137,37 @@ def fit_table_to_clutter( if not object_glb_paths: raise ValueError("No valid settled object GLBs for table fitting.") - z_to_y = _z_up_to_glb_y_up_transform() + z_to_y = GeometryManager.z_up_to_glb_y_up_transform() y_to_z = np.linalg.inv(z_to_y) # Load the table and detect its support surface. - table_scene = _load_table_fit_scene_internal_z( + table_scene = GeometryManager.load_table_fit_scene_internal_z( table_simready_path, trimesh=trimesh, y_to_z=y_to_z, ) - table_mesh = _scene_to_mesh(table_scene, trimesh=trimesh) + table_mesh = GeometryManager.scene_to_mesh(table_scene, trimesh=trimesh) clutter_aabb = clutter_result.get("clutter_2d_aabb_cm") or {} clutter_size = clutter_aabb.get("size_xy", [1.0, 1.0]) target_aspect = float(clutter_size[0]) / max(float(clutter_size[1]), 1.0e-6) - initial_support = _detect_table_fit_support_quad( + initial_support = GeometryManager.detect_table_fit_support_quad( table_mesh, target_aspect=target_aspect, ) # Load the clutter scenes. clutter_scenes = [ - (oid, _load_table_fit_scene_internal_z(path, trimesh=trimesh, y_to_z=y_to_z)) + ( + oid, + GeometryManager.load_table_fit_scene_internal_z( + path, + trimesh=trimesh, + y_to_z=y_to_z, + ), + ) for oid, path in object_glb_paths ] - clutter_bounds = _table_fit_scene_union_bounds( + clutter_bounds = GeometryManager.table_fit_scene_union_bounds( [scene for _, scene in clutter_scenes], trimesh=trimesh, ) @@ -182,10 +181,16 @@ def fit_table_to_clutter( occupancy = float(np.clip(support_occupancy_ratio, 0.1, 1.0)) required_size_cm = clutter_size_cm / occupancy + 2.0 * float(margin_cm) support_size_cm = np.asarray(initial_support["size_xy"], dtype=np.float64) * 100.0 - scale_x = _table_fit_safe_positive_ratio(required_size_cm[0], support_size_cm[0]) - scale_y = _table_fit_safe_positive_ratio(required_size_cm[1], support_size_cm[1]) + scale_x = GeometryManager.table_fit_safe_positive_ratio( + required_size_cm[0], + support_size_cm[0], + ) + scale_y = GeometryManager.table_fit_safe_positive_ratio( + required_size_cm[1], + support_size_cm[1], + ) uniform_scale = max(scale_x, scale_y) - table_scale_transform = _table_fit_uniform_xy_scale_transform( + table_scale_transform = GeometryManager.table_fit_uniform_xy_scale_transform( center_xy=np.asarray(initial_support["center_xy"], dtype=np.float64), scale=uniform_scale, ) @@ -200,8 +205,8 @@ def fit_table_to_clutter( ) # Reposition the table at the origin. - final_table_mesh = _scene_to_mesh(table_scene, trimesh=trimesh) - final_support = _detect_table_fit_support_quad( + final_table_mesh = GeometryManager.scene_to_mesh(table_scene, trimesh=trimesh) + final_support = GeometryManager.detect_table_fit_support_quad( final_table_mesh, target_aspect=float(required_size_cm[0] / max(required_size_cm[1], 1.0e-6)), ) @@ -218,7 +223,10 @@ def fit_table_to_clutter( # Use the highest point of the table mesh (after scaling + gravity + shift) # rather than the support-plane mean Z, so that thin objects sit above the # actual geometry even when the tabletop has slight unevenness. - _table_mesh_after_shift = _scene_to_mesh(table_scene, trimesh=trimesh) + _table_mesh_after_shift = GeometryManager.scene_to_mesh( + table_scene, + trimesh=trimesh, + ) _table_max_z = float( np.asarray(_table_mesh_after_shift.bounds, dtype=np.float64)[1, 2] ) @@ -227,13 +235,13 @@ def fit_table_to_clutter( # Place the objects on the table. placed_objects: list[dict[str, Any]] = [] shifted_clutter: list[tuple[str, Any]] = [] - clutter_after = _table_fit_scene_union_bounds( + clutter_after = GeometryManager.table_fit_scene_union_bounds( [scene for _, scene in clutter_scenes], trimesh=trimesh, ) clutter_center_xy = 0.5 * (clutter_after[0, :2] + clutter_after[1, :2]) for oid, scene in clutter_scenes: - obj_mesh = _scene_to_mesh(scene, trimesh=trimesh) + obj_mesh = GeometryManager.scene_to_mesh(scene, trimesh=trimesh) obj_bounds = np.asarray(obj_mesh.bounds, dtype=np.float64) obj_bottom_z = float(obj_bounds[0, 2]) obj_shift = np.eye(4, dtype=np.float64) @@ -246,16 +254,18 @@ def fit_table_to_clutter( shifted_clutter.append((oid, scene)) # Export the fitted table and placed objects. - final_table_path = output_dir / "table_fit_to_clutter.glb" - _copy_scene_with_transform(table_scene, z_to_y).export(final_table_path) + GeometryManager.copy_scene_with_transform(table_scene, z_to_y).export( + table_output_path + ) for oid, scene in shifted_clutter: - object_path = output_dir / f"{oid}_on_table.glb" - _copy_scene_with_transform(scene, z_to_y).export(object_path) + object_path = object_output_paths.get(oid, output_dir / f"{oid}_on_table.glb") + object_path.parent.mkdir(parents=True, exist_ok=True) + GeometryManager.copy_scene_with_transform(scene, z_to_y).export(object_path) # Compute world-space AABB bottom-centre (sim Z-up coords) before # the scene is converted to GLB Y-up for export. This is the # reference position that gym_export uses to derive ``init_pos``. - _placed_mesh = _scene_to_mesh(scene, trimesh=trimesh) + _placed_mesh = GeometryManager.scene_to_mesh(scene, trimesh=trimesh) _placed_b = np.asarray(_placed_mesh.bounds, dtype=np.float64) world_aabb_bottom_center = [ float(0.5 * (_placed_b[0, 0] + _placed_b[1, 0])), @@ -270,12 +280,11 @@ def fit_table_to_clutter( } ) - # Write the fit manifest. - final_clutter_bounds = _table_fit_scene_union_bounds( + final_clutter_bounds = GeometryManager.table_fit_scene_union_bounds( [scene for _, scene in shifted_clutter], trimesh=trimesh, ) - final_clutter_aabb_cm = _table_fit_bounds_xy_manifest( + final_clutter_aabb_cm = GeometryManager.table_fit_bounds_xy_manifest( final_clutter_bounds, unit_scale=100.0, ) @@ -295,7 +304,7 @@ def fit_table_to_clutter( "status": "ok", "output_dir": str(output_dir), "table_simready_path": str(table_simready_path), - "table_output_path": str(final_table_path), + "table_output_path": str(table_output_path), "objects": placed_objects, "margin_cm": margin_cm, "support_occupancy_ratio": occupancy, @@ -319,9 +328,4 @@ def fit_table_to_clutter( <= float(np.asarray(final_support_centered["size_xy"])[1] * 100.0), }, } - manifest_path = output_dir / "table_fit_to_clutter_manifest.json" - _write_table_fit_json(manifest_path, manifest) - return { - "status": "ok", - "manifest_path": relative_path(manifest_path, output_root), - } + return manifest diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_fit_scene.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_fit_scene.py index 273f15a65..d36cb718f 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_fit_scene.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_fit_scene.py @@ -20,14 +20,29 @@ from pathlib import Path from typing import Any -from embodichain.gen_sim.prompt2scene.agent_tools.managers.table_clutter_fit_manager import ( +from embodichain.gen_sim.prompt2scene.agent_tools.tools.table_clutter_fit import ( fit_table_to_clutter, ) +from embodichain.gen_sim.prompt2scene.utils.io import relative_path, write_json from embodichain.gen_sim.prompt2scene.utils.log import log_info, log_warning __all__ = ["fit_image_scene_table", "fit_text_scene_table"] +def _table_fit_output_paths( + *, + output_dir: Path, + object_ids: list[str], +) -> tuple[Path, dict[str, Path]]: + return ( + output_dir / "table_fit_to_clutter.glb", + { + object_id: output_dir / f"{object_id}_on_table.glb" + for object_id in object_ids + }, + ) + + def fit_text_scene_table( *, table_result: dict[str, Any], @@ -37,13 +52,30 @@ def fit_text_scene_table( ) -> dict[str, Any]: """Fit the text-scene table and convert failures to result data.""" try: - result = fit_table_to_clutter( + object_ids = [ + str(item["id"]) + for item in clutter_layout_result.get("objects", []) + if isinstance(item, dict) and item.get("id") and item.get("status") == "ok" + ] + table_output_path, object_output_paths = _table_fit_output_paths( + output_dir=output_dir, + object_ids=object_ids, + ) + manifest = fit_table_to_clutter( table_result=table_result, clutter_result=clutter_layout_result, output_root=output_root, output_dir=output_dir, + table_output_path=table_output_path, + object_output_paths=object_output_paths, object_coverage_percent=table_result.get("object_coverage_percent"), ) + manifest_path = output_dir / "table_fit_to_clutter_manifest.json" + write_json(manifest_path, manifest) + result = { + "status": "ok", + "manifest_path": relative_path(str(manifest_path), output_root), + } log_info(f"text table fit completed status={result.get('status')}") return result except Exception as exc: @@ -76,6 +108,17 @@ def fit_image_scene_table( } try: + object_ids = [ + str(item["id"]) + for item in generated_objects + if isinstance(item, dict) + and item.get("id") + and item.get("aligned_geometry_path") + ] + table_output_path, object_output_paths = _table_fit_output_paths( + output_dir=output_dir, + object_ids=object_ids, + ) clutter_result = { "clutter_2d_aabb_cm": alignment_result.get( "final_clutter_2d_aabb_cm" @@ -90,13 +133,21 @@ def fit_image_scene_table( if item.get("id") and item.get("aligned_geometry_path") ], } - result = fit_table_to_clutter( + manifest = fit_table_to_clutter( table_result=generated_table, clutter_result=clutter_result, output_root=output_root, output_dir=output_dir, + table_output_path=table_output_path, + object_output_paths=object_output_paths, object_coverage_percent=generated_table.get("object_coverage_percent"), ) + manifest_path = output_dir / "table_fit_to_clutter_manifest.json" + write_json(manifest_path, manifest) + result = { + "status": "ok", + "manifest_path": relative_path(str(manifest_path), output_root), + } log_info(f"image table fit completed status={result.get('status')}") return result except Exception as exc: diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_clutter_layout.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_clutter_layout.py index 80bc32100..c77f7f866 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_clutter_layout.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_clutter_layout.py @@ -21,7 +21,7 @@ from typing import Any from embodichain.gen_sim.prompt2scene.utils.log import log_info, log_warning -from embodichain.gen_sim.prompt2scene.agent_tools.managers.text_layout_manager import ( +from embodichain.gen_sim.prompt2scene.agent_tools.tools.text_object_settle import ( settle_text_objects_to_ground, ) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/settle.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_object_settle.py similarity index 84% rename from embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/settle.py rename to embodichain/gen_sim/prompt2scene/agent_tools/tools/text_object_settle.py index da3cdde6e..b5acfa121 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/text_layout_manager/settle.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_object_settle.py @@ -30,18 +30,11 @@ from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager.schemas import ( GravityDropRequest, ) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.optimization_manager import ( - _object_scenes_xy_aabb_manifest, +from embodichain.gen_sim.prompt2scene.agent_tools.managers.layout_manager import ( + LayoutManager, ) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager.scene_geometry import ( - _aabb_bottom_to_xy_plane_transform, - _copy_scene_with_transform, - _matrix_from_json, - _scale_transform, - _scene_to_mesh, - _xy_aabb_center, - _xy_aabb_size, - _z_up_to_glb_y_up_transform, +from embodichain.gen_sim.prompt2scene.agent_tools.managers.geometry_manager import ( + GeometryManager, ) from embodichain.gen_sim.prompt2scene.utils.io import ( relative_path, @@ -52,9 +45,6 @@ MatplotlibManager, RenderFootprintLayoutRequest, ) -from embodichain.gen_sim.prompt2scene.agent_tools.managers.text_layout_manager.layout import ( - _layout_text_objects_grid, -) __all__ = ["settle_text_objects_to_ground"] @@ -89,7 +79,7 @@ def settle_text_objects_to_ground( output_dir = output_dir.expanduser().resolve() output_dir.mkdir(parents=True, exist_ok=True) sim = SimulationManager(headless=True, sim_device=sim_device) - z_to_y = _z_up_to_glb_y_up_transform() + z_to_y = GeometryManager.z_up_to_glb_y_up_transform() y_to_z = np.linalg.inv(z_to_y) settled_objects: list[dict[str, Any]] = [] @@ -144,21 +134,29 @@ def settle_text_objects_to_ground( try: # Load simready (GLB Y-up) → convert to internal Z-up scene_yup = trimesh.load(simready_path, force="scene") - scene = _copy_scene_with_transform(scene_yup, y_to_z) + scene = GeometryManager.copy_scene_with_transform(scene_yup, y_to_z) # Apply real-world scale - scale_transform = _scale_transform(scale_factor) + scale_transform = GeometryManager.scale_transform(scale_factor) scene.apply_transform(scale_transform) # Settle the object under gravity. - mesh = _scene_to_mesh(scene, trimesh=trimesh) + mesh = GeometryManager.scene_to_mesh(scene, trimesh=trimesh) mesh_bounds = np.asarray(mesh.bounds, dtype=np.float64) mesh_z_height = max(float(mesh_bounds[1][2] - mesh_bounds[0][2]), 0.0) - bottom_to_xy = _aabb_bottom_to_xy_plane_transform(mesh_bounds) - normalized_scene = _copy_scene_with_transform(scene, bottom_to_xy) + bottom_to_xy = GeometryManager.aabb_bottom_to_xy_plane_transform( + mesh_bounds + ) + normalized_scene = GeometryManager.copy_scene_with_transform( + scene, + bottom_to_xy, + ) # Export to Y-up GLB for gravity - pre_gravity_scene = _copy_scene_with_transform(normalized_scene, z_to_y) + pre_gravity_scene = GeometryManager.copy_scene_with_transform( + normalized_scene, + z_to_y, + ) pre_gravity_path = tmp_path / f"{obj_id}_pre_gravity.glb" pre_gravity_scene.export(pre_gravity_path) gravity_initial_height = mesh_z_height * 0.1 @@ -174,7 +172,7 @@ def settle_text_objects_to_ground( initial_height=gravity_initial_height, ) ) - gravity_transform = _matrix_from_json( + gravity_transform = GeometryManager.matrix_from_json( gravity_result.final_pose, name=f"{obj_id}.gravity_final_pose", ) @@ -183,16 +181,19 @@ def settle_text_objects_to_ground( gravity_reason = traceback.format_exc() # Apply gravity result (in internal Z-up space) - settled_scene = _copy_scene_with_transform( + settled_scene = GeometryManager.copy_scene_with_transform( normalized_scene, gravity_transform, ) # Center the bottom of the AABB at the XY origin. - settled_mesh = _scene_to_mesh(settled_scene, trimesh=trimesh) + settled_mesh = GeometryManager.scene_to_mesh( + settled_scene, + trimesh=trimesh, + ) settled_bounds = np.asarray(settled_mesh.bounds, dtype=np.float64) - settled_xy_center = _xy_aabb_center(settled_bounds) - settled_xy_size = _xy_aabb_size(settled_bounds) + settled_xy_center = GeometryManager.xy_aabb_center(settled_bounds) + settled_xy_size = GeometryManager.xy_aabb_size(settled_bounds) settled_bottom_z = float(settled_bounds[0, 2]) centre_transform = np.eye(4, dtype=np.float64) @@ -201,19 +202,22 @@ def settle_text_objects_to_ground( -float(settled_xy_center[1]), -settled_bottom_z, ] - centred_scene = _copy_scene_with_transform( + centred_scene = GeometryManager.copy_scene_with_transform( settled_scene, centre_transform, ) # Verify final bounds - centred_mesh = _scene_to_mesh(centred_scene, trimesh=trimesh) + centred_mesh = GeometryManager.scene_to_mesh( + centred_scene, + trimesh=trimesh, + ) centred_bounds = np.asarray(centred_mesh.bounds, dtype=np.float64) - centred_xy_size = _xy_aabb_size(centred_bounds) + centred_xy_size = GeometryManager.xy_aabb_size(centred_bounds) # Export settled GLB (Z-up → Y-up for GLB output) settled_glb_path = output_dir / f"{obj_id}_settled.glb" - _copy_scene_with_transform(centred_scene, z_to_y).export( + GeometryManager.copy_scene_with_transform(centred_scene, z_to_y).export( settled_glb_path ) @@ -261,13 +265,15 @@ def settle_text_objects_to_ground( if object_scenes: xy_sizes = { oid: np.asarray( - _xy_aabb_size(_scene_to_mesh(scene, trimesh=trimesh).bounds), + GeometryManager.xy_aabb_size( + GeometryManager.scene_to_mesh(scene, trimesh=trimesh).bounds + ), dtype=np.float64, ) for oid, scene in object_scenes } relations = list(spatial_relations or []) - layout_result = _layout_text_objects_grid( + layout_result = LayoutManager.layout_text_objects_grid( object_ids=[oid for oid, _ in object_scenes], xy_sizes=xy_sizes, spatial_relations=relations, @@ -320,21 +326,27 @@ def settle_text_objects_to_ground( laid_out_scenes: list[tuple[str, Any]] = [] for oid, scene in object_scenes: target_xy = target_centers[oid] - settled_mesh = _scene_to_mesh(scene, trimesh=trimesh) + settled_mesh = GeometryManager.scene_to_mesh(scene, trimesh=trimesh) settled_bounds = np.asarray(settled_mesh.bounds, dtype=np.float64) - current_xy = _xy_aabb_center(settled_bounds) + current_xy = GeometryManager.xy_aabb_center(settled_bounds) placement = np.eye(4, dtype=np.float64) placement[:3, 3] = [ float(target_xy[0] - current_xy[0]), float(target_xy[1] - current_xy[1]), 0.0, ] - laid_out_scene = _copy_scene_with_transform(scene, placement) + laid_out_scene = GeometryManager.copy_scene_with_transform( + scene, + placement, + ) laid_out_scenes.append((oid, laid_out_scene)) # Export laid-out GLB (replaces the origin-centred one) laid_out_glb_path = output_dir / f"{oid}_laid_out.glb" - _copy_scene_with_transform(laid_out_scene, z_to_y).export(laid_out_glb_path) + GeometryManager.copy_scene_with_transform( + laid_out_scene, + z_to_y, + ).export(laid_out_glb_path) # Update per-object metadata with layout position for item in settled_objects: @@ -345,17 +357,20 @@ def settle_text_objects_to_ground( str(laid_out_glb_path), output_root ) laid_out_bounds = np.asarray( - _scene_to_mesh(laid_out_scene, trimesh=trimesh).bounds, + GeometryManager.scene_to_mesh( + laid_out_scene, + trimesh=trimesh, + ).bounds, dtype=np.float64, ) item["laid_out_xy_size_cm"] = ( - _xy_aabb_size(laid_out_bounds) * 100.0 + GeometryManager.xy_aabb_size(laid_out_bounds) * 100.0 ).tolist() break object_scenes = laid_out_scenes - clutter_2d_aabb_cm = _object_scenes_xy_aabb_manifest( + clutter_2d_aabb_cm = LayoutManager.object_scenes_xy_aabb_manifest( object_scenes=object_scenes, trimesh=trimesh, unit_scale=100.0, diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_scene_metric_scale.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_scene_metric_scale.py index fd0b13835..b0ceb0cae 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_scene_metric_scale.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/text_scene_metric_scale.py @@ -20,7 +20,7 @@ from pathlib import Path from typing import Any -from embodichain.gen_sim.prompt2scene.agent_tools.managers.metric_scale_manager import ( +from embodichain.gen_sim.prompt2scene.agent_tools.managers.simready_manager import ( METRIC_SCALE_ENABLED, EstimateMetricScalesRequest, MetricScaleManager, diff --git a/embodichain/gen_sim/prompt2scene/workflows/llm_output.py b/embodichain/gen_sim/prompt2scene/llms/llm_output.py similarity index 92% rename from embodichain/gen_sim/prompt2scene/workflows/llm_output.py rename to embodichain/gen_sim/prompt2scene/llms/llm_output.py index bcc98bcbb..07706a11c 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/llm_output.py +++ b/embodichain/gen_sim/prompt2scene/llms/llm_output.py @@ -17,14 +17,8 @@ from __future__ import annotations import json -from pathlib import Path from typing import Any, Callable -from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( - WorkflowArtifactWriter, - write_next_raw_model_output, -) - __all__ = [ "bind_structured_output", "coerce_json_object_output", @@ -114,11 +108,7 @@ def call_structured_json_model_step( schema: dict[str, Any], messages: list[dict[str, Any]], context: str, - step_name: str, - output_root: Path | None, attempt_count: int, - raw_output_label: str | None = None, - artifact_writer: WorkflowArtifactWriter | None = None, raw_output_writer: Callable[[dict[str, Any]], None] | None = None, ) -> dict[str, Any]: """Call a structured-output model, validate JSON, and persist raw output.""" @@ -142,18 +132,6 @@ def call_structured_json_model_step( if raw_output_writer is not None: raw_output_writer(raw_model_output) - elif artifact_writer is not None: - artifact_writer.write_next_raw_model_output( - payload=raw_model_output, - label=raw_output_label, - ) - elif output_root is not None: - write_next_raw_model_output( - output_root=output_root, - step_name=step_name, - payload=raw_model_output, - label=raw_output_label, - ) return raw_model_output diff --git a/embodichain/gen_sim/prompt2scene/pipeline/runner.py b/embodichain/gen_sim/prompt2scene/pipeline/runner.py index 7931f00ba..90d788c3e 100644 --- a/embodichain/gen_sim/prompt2scene/pipeline/runner.py +++ b/embodichain/gen_sim/prompt2scene/pipeline/runner.py @@ -24,15 +24,16 @@ InputKind, Prompt2SceneInput, ) -from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( +from embodichain.gen_sim.prompt2scene.workflows.paths import ( IMAGE_SEGMENTS_STEP, IMAGE_SPATIAL_RELATIONS_STEP, SCENE_INTAKE_STEP, - STEP_RESULT_FILENAME, - step_result_path, - write_step_result, TEXT_RELATIONS_STEP, UNIFIED_SCENE_STEP, + PipelinePaths, +) +from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( + write_step_result, ) from embodichain.gen_sim.prompt2scene.workflows.unified_scene.graph import ( run_unified_scene, @@ -40,7 +41,7 @@ from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.graph import ( run_unified_scene_gen, ) -from embodichain.gen_sim.prompt2scene.agent_tools.tools.gym_export import ( +from embodichain.gen_sim.prompt2scene.workflows.gym_export import ( export_gym_config, ) from embodichain.gen_sim.prompt2scene.utils.io import write_json @@ -118,6 +119,7 @@ def run_prompt2scene( f"input_kind={request.input_kind.value} output_root={request.output_root}" ) request.output_root.mkdir(parents=True, exist_ok=True) + paths = PipelinePaths(request.output_root) manifest_path = request.output_root / INPUT_MANIFEST_FILENAME manifest = request.to_manifest() if llm_cfg is not None: @@ -149,8 +151,7 @@ def run_prompt2scene( llm_cfg=llm_cfg, output_root=request.output_root, ) - image_segments_path = step_result_path( - request.output_root, + image_segments_path = paths.step_result( IMAGE_SEGMENTS_STEP, ) if not image_segments_path.is_file(): @@ -159,8 +160,7 @@ def run_prompt2scene( IMAGE_SEGMENTS_STEP, image_relations.to_segmentation_manifest(), ) - image_spatial_relations_path = step_result_path( - request.output_root, + image_spatial_relations_path = paths.step_result( IMAGE_SPATIAL_RELATIONS_STEP, ) if not image_spatial_relations_path.is_file(): @@ -180,8 +180,7 @@ def run_prompt2scene( image_relations=image_relations, output_root=request.output_root, ) - unified_scene_path = step_result_path( - request.output_root, + unified_scene_path = paths.step_result( UNIFIED_SCENE_STEP, ) else: @@ -192,8 +191,7 @@ def run_prompt2scene( llm_cfg=llm_cfg, output_root=request.output_root, ) - text_relations_path = step_result_path( - request.output_root, + text_relations_path = paths.step_result( TEXT_RELATIONS_STEP, ) log.log_info( @@ -206,8 +204,7 @@ def run_prompt2scene( text_relations=text_relations, output_root=request.output_root, ) - unified_scene_path = step_result_path( - request.output_root, + unified_scene_path = paths.step_result( UNIFIED_SCENE_STEP, ) log.log_info( diff --git a/embodichain/gen_sim/prompt2scene/prompts/__init__.py b/embodichain/gen_sim/prompt2scene/prompts/__init__.py index f72a97f6d..f772b5399 100644 --- a/embodichain/gen_sim/prompt2scene/prompts/__init__.py +++ b/embodichain/gen_sim/prompt2scene/prompts/__init__.py @@ -21,7 +21,12 @@ default_prompt_renderer = PromptRenderer(data) -__all__ = ["load_prompt", "load_prompt_data", "render_prompt", "default_prompt_renderer"] +__all__ = [ + "load_prompt", + "load_prompt_data", + "render_prompt", + "default_prompt_renderer", +] def load_prompt(prompt_name: str) -> str: diff --git a/embodichain/gen_sim/prompt2scene/prompts/builders.py b/embodichain/gen_sim/prompt2scene/prompts/builders.py new file mode 100644 index 000000000..8596c32dc --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/prompts/builders.py @@ -0,0 +1,394 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.prompts import render_prompt +from embodichain.gen_sim.prompt2scene.utils.io import image_to_data_url + +__all__ = [ + "build_filter_extra_instances_messages", + "build_image_metric_scale_messages", + "build_scene_intake_messages", + "build_scene_intake_verifier_messages", + "build_spatial_layout_messages", + "build_text_metric_scale_messages", + "build_text_relation_messages", + "build_up_down_flip_check_messages", +] + + +SCENE_INTAKE_PROMPT = "scene_intake.yaml" +IMAGE_RELATIONS_PROMPT = "image_relations.yaml" +TEXT_RELATIONS_PROMPT = "text_relations.yaml" +UNIFIED_SCENE_GEN_PROMPT = "unified_scene_gen.yaml" + + + +def build_scene_intake_messages(request: Prompt2SceneInput) -> list[dict[str, Any]]: + """Build LangChain-compatible messages for scene intake.""" + + from embodichain.gen_sim.prompt2scene.workflows.request import InputKind + + if request.input_kind == InputKind.TEXT: + return [ + { + "role": "system", + "content": render_prompt( + SCENE_INTAKE_PROMPT, prompt_key="text_system" + ), + }, + { + "role": "user", + "content": render_prompt( + SCENE_INTAKE_PROMPT, + {"text": request.text or ""}, + prompt_key="text_user", + ), + }, + ] + return [ + { + "role": "system", + "content": render_prompt( + SCENE_INTAKE_PROMPT, prompt_key="image_system" + ), + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": render_prompt( + SCENE_INTAKE_PROMPT, prompt_key="image_user" + ), + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(request.image_path)}, + }, + ], + }, + ] + + +def build_scene_intake_verifier_messages( + *, + request: Prompt2SceneInput, + scene_intake: SceneIntakeSpec, +) -> list[dict[str, Any]]: + """Build messages for scene-intake group and count verification.""" + + from embodichain.gen_sim.prompt2scene.workflows.request import InputKind + + table_draft: dict[str, object] = { + "name": scene_intake.table.name, + "description": scene_intake.table.description, + "complete_table_description": ( + scene_intake.table.complete_table_description + ), + "is_complete_visible_table": scene_intake.table.is_complete_visible_table, + "class_candidate": list(scene_intake.table.class_candidate), + } + if scene_intake.table.object_coverage_percent is not None: + table_draft["object_coverage_percent"] = ( + scene_intake.table.object_coverage_percent + ) + scene_intake_json = json.dumps( + { + "table": table_draft, + "assets": [ + { + "name": asset.name, + "description": asset.description, + "class_candidate": list(asset.class_candidate), + "count": asset.count, + } + for asset in scene_intake.assets + ], + }, + ensure_ascii=False, + indent=2, + ) + + if request.input_kind == InputKind.TEXT: + return [ + { + "role": "system", + "content": render_prompt( + SCENE_INTAKE_PROMPT, prompt_key="verifier_system" + ), + }, + { + "role": "user", + "content": render_prompt( + SCENE_INTAKE_PROMPT, + { + "text": request.text or "", + "scene_intake_json": scene_intake_json, + }, + prompt_key="verifier_text_user", + ), + }, + ] + + image_path = request.image_path + if image_path is None: + raise ValueError("Image input requires image_path.") + return [ + { + "role": "system", + "content": render_prompt( + SCENE_INTAKE_PROMPT, prompt_key="verifier_system" + ), + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": render_prompt( + SCENE_INTAKE_PROMPT, + {"scene_intake_json": scene_intake_json}, + prompt_key="verifier_image_user", + ), + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(image_path)}, + }, + ], + }, + ] + + + + +def build_filter_extra_instances_messages( + *, + debug_image_path: Path, + name: str, + description: str, + expected_count: int, + class_candidate: list[str], +) -> list[dict[str, Any]]: + """Build LangChain-compatible messages for VLM extra-mask filtering.""" + return [ + { + "role": "system", + "content": render_prompt( + IMAGE_RELATIONS_PROMPT, prompt_key="filter_extra_instances_system" + ), + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": render_prompt( + IMAGE_RELATIONS_PROMPT, + { + "name": name.replace("_", " "), + "description": description, + "expected_count": str(expected_count), + "class_candidate": ", ".join( + candidate.replace("_", " ") + for candidate in class_candidate + ), + }, + prompt_key="filter_extra_instances_user", + ), + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(debug_image_path)}, + }, + ], + }, + ] + + +def build_spatial_layout_messages( + *, + bbox_name_image_path: Path, + asset_ids: list[str], +) -> list[dict[str, Any]]: + """Build messages for VLM spatial ordering and object-state extraction.""" + return [ + { + "role": "system", + "content": render_prompt( + IMAGE_RELATIONS_PROMPT, prompt_key="spatial_layout_system" + ), + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": render_prompt( + IMAGE_RELATIONS_PROMPT, + { + "asset_ids": "\n".join( + f"- {asset_id}" for asset_id in asset_ids + ), + }, + prompt_key="spatial_layout_user", + ), + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(bbox_name_image_path)}, + }, + ], + }, + ] + + + + +def build_text_relation_messages( + *, + request: Prompt2SceneInput, + scene_intake: SceneIntakeSpec, +) -> list[dict[str, Any]]: + """Build messages for explicit text spatial-relation extraction.""" + asset_names = "\n".join(f"- {asset.name}" for asset in scene_intake.assets) + return [ + { + "role": "system", + "content": render_prompt(TEXT_RELATIONS_PROMPT, prompt_key="system"), + }, + { + "role": "user", + "content": render_prompt( + TEXT_RELATIONS_PROMPT, + { + "asset_names": asset_names, + "text": request.text or "", + }, + prompt_key="user", + ), + }, + ] + + + + +def build_image_metric_scale_messages( + *, + bbox_name_image_path: Path, + objects_json: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Build messages for image-scene object metric scale estimation.""" + return [ + { + "role": "system", + "content": render_prompt( + UNIFIED_SCENE_GEN_PROMPT, prompt_key="image_metric_scale_system" + ), + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": render_prompt( + UNIFIED_SCENE_GEN_PROMPT, + { + "objects_json": json.dumps( + objects_json, ensure_ascii=False, indent=2 + ), + }, + prompt_key="image_metric_scale_user", + ), + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(bbox_name_image_path)}, + }, + ], + }, + ] + + +def build_text_metric_scale_messages( + *, + user_text: str, + objects_json: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Build messages for text-scene object metric scale estimation.""" + return [ + { + "role": "system", + "content": render_prompt( + UNIFIED_SCENE_GEN_PROMPT, prompt_key="text_metric_scale_system" + ), + }, + { + "role": "user", + "content": render_prompt( + UNIFIED_SCENE_GEN_PROMPT, + { + "user_text": user_text, + "objects_json": json.dumps( + objects_json, ensure_ascii=False, indent=2 + ), + }, + prompt_key="text_metric_scale_user", + ), + }, + ] + + +def build_up_down_flip_check_messages( + *, + original_image_path: Path, + comparison_image_path: Path, +) -> list[dict[str, Any]]: + """Build messages for VLM support-normal up/down flip verification.""" + return [ + { + "role": "system", + "content": render_prompt( + UNIFIED_SCENE_GEN_PROMPT, prompt_key="up_down_flip_check_system" + ), + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": render_prompt( + UNIFIED_SCENE_GEN_PROMPT, + prompt_key="up_down_flip_check_user", + ), + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(original_image_path)}, + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(comparison_image_path)}, + }, + ], + }, + ] diff --git a/embodichain/gen_sim/prompt2scene/prompts/data/text_relations.yaml b/embodichain/gen_sim/prompt2scene/prompts/data/text_relations.yaml index 7a267d091..fff1f1fa2 100644 --- a/embodichain/gen_sim/prompt2scene/prompts/data/text_relations.yaml +++ b/embodichain/gen_sim/prompt2scene/prompts/data/text_relations.yaml @@ -61,14 +61,39 @@ system: | leaning against another object. + + These examples are part of the rules — follow the mapping exactly. + + Text: "A is to the left of B" + Output: {"subject": "A", "relation": "left_of", "object": "B", + "evidence": "The text says A is to the left of B."} + + Text: "A is to the right of B" ← NOTE: A is right of B means B is left of A + Output: {"subject": "B", "relation": "left_of", "object": "A", + "evidence": "The text says A is to the right of B."} + + Text: "A is in front of B" + Output: {"subject": "A", "relation": "front_of", "object": "B", + "evidence": "The text says A is in front of B."} + + Text: "behind the computer, there is a mouse" ← mouse is behind computer means computer is front_of mouse + Output: {"subject": "computer", "relation": "front_of", "object": "mouse", + "evidence": "The text says behind the computer there is a mouse."} + + Text: "the mouse is at the right side of the computer" + The mouse is RIGHT of the computer, so the computer is LEFT of the mouse: + Output: {"subject": "computer", "relation": "left_of", "object": "mouse", + "evidence": "The text says the mouse is at the right side of the computer."} + + { "object_relations": [ { - "subject": "paper_cup", + "subject": "computer", "relation": "left_of", - "object": "plate", - "evidence": "The text says the paper cup is left of the plate." + "object": "mouse", + "evidence": "The text says the mouse is at the right side of the computer." } ], "table_constraints": [ diff --git a/embodichain/gen_sim/prompt2scene/prompts/schemas.py b/embodichain/gen_sim/prompt2scene/prompts/schemas.py new file mode 100644 index 000000000..20d617962 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/prompts/schemas.py @@ -0,0 +1,354 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""JSON schemas for LLM structured-output calls across all workflows.""" + +from __future__ import annotations + +from typing import Any + +from embodichain.gen_sim.prompt2scene.agent_tools.tools.spatial_relations import ( + GRID_VALUE_LIST, + RELATION_VALUE_LIST, +) + +__all__ = [ + "FILTER_EXTRA_INSTANCES_JSON_SCHEMA", + "IMAGE_METRIC_SCALE_JSON_SCHEMA", + "SCENE_INTAKE_JSON_SCHEMA", + "SPATIAL_LAYOUT_JSON_SCHEMA", + "TEXT_RELATIONS_JSON_SCHEMA", + "UP_DOWN_FLIP_CHECK_JSON_SCHEMA", +] + + +SCENE_INTAKE_JSON_SCHEMA: dict[str, Any] = { + "title": "SceneIntakeModelOutput", + "description": ( + "Objects and table information extracted from a text or image input." + ), + "type": "object", + "additionalProperties": False, + "properties": { + "table": { + "type": "object", + "additionalProperties": False, + "properties": { + "name": { + "type": "string", + "description": ( + "Canonical English class name for the visible table " + "or tabletop target, such as table, desk, dining_table, " + "coffee_table, workbench, or tabletop." + ), + }, + "description": { + "type": "string", + "minLength": 20, + "maxLength": 180, + "description": ( + "One concise standalone appearance description of the " + "visible table or tabletop region." + ), + }, + "complete_table_description": { + "type": "string", + "minLength": 20, + "maxLength": 220, + "description": ( + "One concise standalone description of a complete table " + "asset for text-to-3D generation, matching the visible " + "tabletop color, material, and texture." + ), + }, + "is_complete_visible_table": { + "type": "boolean", + "description": ( + "For image input, whether a mostly complete table is " + "visible and suitable as the final table geometry source. " + "For text input, this should be false." + ), + }, + "class_candidate": { + "type": "array", + "minItems": 5, + "maxItems": 5, + "description": ( + "Exactly five likely class names for segmenting the " + "visible table or tabletop target." + ), + "items": { + "type": "string", + "minLength": 1, + }, + }, + "object_coverage_percent": { + "type": "integer", + "enum": [10, 30, 50, 70], + "description": ( + "For image input with a complete visible table ONLY: " + "choose the closest coverage bucket for objects on the " + "tabletop: 10 (mostly empty, a few small objects), " + "30 (lightly cluttered), 50 (moderately cluttered), " + "70 (densely packed). Omit this field entirely for " + "text input or when is_complete_visible_table is false." + ), + }, + }, + "required": [ + "name", + "description", + "complete_table_description", + "is_complete_visible_table", + "class_candidate", + ], + }, + "assets": { + "type": "array", + "description": ( + "Object category groups on or intended for the tabletop scene." + ), + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "name": { + "type": "string", + "description": ( + "Canonical English object name, singular, " + "snake_case preferred." + ), + }, + "description": { + "type": "string", + "minLength": 20, + "maxLength": 180, + "description": ( + "One concise appearance description of the object for " + "image and 3D geometry generation." + ), + }, + "class_candidate": { + "type": "array", + "minItems": 5, + "maxItems": 5, + "description": ( + "Exactly five likely object class names for later " + "image detection or segmentation." + ), + "items": { + "type": "string", + "minLength": 1, + }, + }, + "count": { + "type": "integer", + "description": ( + "Number of repeated instances in this object category " + "group. Only group objects that can share the same name, " + "description, and class_candidate list." + ), + "minimum": 1, + }, + }, + "required": ["name", "description", "class_candidate", "count"], + }, + }, + }, + "required": ["table", "assets"], +} + + +FILTER_EXTRA_INSTANCES_JSON_SCHEMA: dict[str, Any] = { + "title": "FilterExtraImageInstancesOutput", + "type": "object", + "additionalProperties": False, + "properties": { + "extra_instance_numbers": { + "type": "array", + "description": "1-based mask numbers that should be removed.", + "items": {"type": "integer", "minimum": 1}, + }, + "reason": { + "type": "string", + "description": "Brief reason for the removal decision.", + }, + }, + "required": ["extra_instance_numbers", "reason"], +} + +SPATIAL_LAYOUT_JSON_SCHEMA: dict[str, Any] = { + "title": "ImageSpatialLayoutOutput", + "type": "object", + "additionalProperties": False, + "properties": { + "anchor": { + "type": "object", + "additionalProperties": False, + "properties": { + "asset_id": {"type": "string", "minLength": 1}, + "grid": { + "type": "string", + "enum": GRID_VALUE_LIST, + }, + "reason": {"type": "string"}, + }, + "required": ["asset_id", "grid", "reason"], + }, + "x_order": { + "type": "array", + "description": "Asset-id groups ordered from left to right.", + "items": { + "type": "array", + "items": {"type": "string", "minLength": 1}, + "minItems": 1, + }, + "minItems": 1, + }, + "y_order": { + "type": "array", + "description": "Asset-id groups ordered from front to back.", + "items": { + "type": "array", + "items": {"type": "string", "minLength": 1}, + "minItems": 1, + }, + "minItems": 1, + }, + "asset_states": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": True, + "properties": { + "asset_id": {"type": "string", "minLength": 1}, + "is_arbitrary_layout": {"type": "boolean"}, + "reason": {"type": "string", "minLength": 1}, + }, + "required": [ + "asset_id", + "is_arbitrary_layout", + "reason", + ], + }, + }, + }, + "required": ["anchor", "x_order", "y_order", "asset_states"], +} + + +TEXT_RELATIONS_JSON_SCHEMA: dict[str, Any] = { + "title": "TextRelationsOutput", + "type": "object", + "additionalProperties": False, + "properties": { + "object_relations": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "subject": {"type": "string", "minLength": 1}, + "relation": { + "type": "string", + "enum": RELATION_VALUE_LIST, + }, + "object": {"type": "string", "minLength": 1}, + "evidence": {"type": "string", "minLength": 1}, + }, + "required": ["subject", "relation", "object", "evidence"], + }, + }, + "table_constraints": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "asset": {"type": "string", "minLength": 1}, + "grid": { + "type": "string", + "enum": GRID_VALUE_LIST, + }, + "evidence": {"type": "string", "minLength": 1}, + }, + "required": ["asset", "grid", "evidence"], + }, + }, + "object_layouts": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "asset": {"type": "string", "minLength": 1}, + "is_arbitrary_layout": {"type": "boolean"}, + "reason": {"type": "string", "minLength": 1}, + }, + "required": ["asset", "is_arbitrary_layout", "reason"], + }, + }, + }, + "required": ["object_relations", "table_constraints", "object_layouts"], +} + + +UP_DOWN_FLIP_CHECK_JSON_SCHEMA: dict[str, Any] = { + "title": "AlignedUpDownFlipCheckOutput", + "type": "object", + "additionalProperties": False, + "properties": { + "selected_number": {"type": "integer", "enum": [1, 2]}, + "confidence": {"type": "number", "minimum": 0.0, "maximum": 1.0}, + "reason": {"type": "string"}, + }, + "required": ["selected_number", "confidence", "reason"], +} + +IMAGE_METRIC_SCALE_JSON_SCHEMA: dict[str, Any] = { + "title": "ImageMetricScaleEstimate", + "type": "object", + "additionalProperties": False, + "properties": { + "object_scales": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "object_id": {"type": "string"}, + "bbox_dims_cm": { + "type": "array", + "minItems": 3, + "maxItems": 3, + "items": { + "type": "number", + "minimum": 1.0e-6, + }, + }, + "confidence": { + "type": "number", + "minimum": 0.0, + "maximum": 1.0, + }, + "reason": {"type": "string"}, + }, + "required": ["object_id", "bbox_dims_cm", "confidence", "reason"], + }, + }, + }, + "required": ["object_scales"], +} diff --git a/embodichain/gen_sim/prompt2scene/workflows/artifact_writer.py b/embodichain/gen_sim/prompt2scene/workflows/artifact_writer.py index 6587ccbbc..c535a9701 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/artifact_writer.py +++ b/embodichain/gen_sim/prompt2scene/workflows/artifact_writer.py @@ -17,10 +17,26 @@ from __future__ import annotations from pathlib import Path -import re from typing import Any from embodichain.gen_sim.prompt2scene.utils.io import write_json +from embodichain.gen_sim.prompt2scene.workflows.paths import ( + DEBUG_DIRNAME, + IMAGE_SEGMENTS_STEP, + IMAGE_SPATIAL_RELATIONS_STEP, + RAW_MODEL_OUTPUT_FILENAME, + SCENE_INTAKE_STEP, + STEP_RESULT_FILENAME, + TEXT_RELATIONS_STEP, + UNIFIED_SCENE_GEN_STEP, + UNIFIED_SCENE_STEP, + debug_dir_path, + debug_round_dir_path, + next_debug_round_dir_path, + next_debug_round_name, + step_dir_path, + step_result_path, +) __all__ = [ "DEBUG_DIRNAME", @@ -33,12 +49,6 @@ "UNIFIED_SCENE_GEN_STEP", "UNIFIED_SCENE_STEP", "WorkflowArtifactWriter", - "debug_dir_path", - "debug_round_dir_path", - "next_debug_round_dir_path", - "next_debug_round_name", - "step_dir_path", - "step_result_path", "write_debug_json", "write_debug_round_json", "write_next_raw_model_output", @@ -46,84 +56,12 @@ "write_step_result", ] -STEP_RESULT_FILENAME = "result.json" -DEBUG_DIRNAME = "debug" -RAW_MODEL_OUTPUT_FILENAME = "raw_model_output.json" - -SCENE_INTAKE_STEP = "scene_intake" -IMAGE_SEGMENTS_STEP = "image_segments" -IMAGE_SPATIAL_RELATIONS_STEP = "image_spatial_relations" -TEXT_RELATIONS_STEP = "text_relations" -UNIFIED_SCENE_STEP = "unified_scene" -UNIFIED_SCENE_GEN_STEP = "unified_scene_gen" - -DEBUG_ROUND_PATTERN = re.compile(r"^round_(\d+)(?:_|$)") - - -def step_dir_path(output_root: Path, step_name: str) -> Path: - """Return the directory path for a pipeline step.""" - return output_root / step_name - - -def step_result_path(output_root: Path, step_name: str) -> Path: - """Return the final result JSON path for a pipeline step.""" - return step_dir_path(output_root, step_name) / STEP_RESULT_FILENAME - - -def debug_dir_path(output_root: Path, step_name: str) -> Path: - """Return the debug directory path for a pipeline step.""" - return step_dir_path(output_root, step_name) / DEBUG_DIRNAME - - -def debug_round_dir_path( - output_root: Path, - step_name: str, - round_name: str, -) -> Path: - """Return a debug subdirectory path for one model/tool round.""" - return debug_dir_path(output_root, step_name) / round_name - - -def next_debug_round_name( - output_root: Path, - step_name: str, - label: str | None = None, -) -> str: - """Return the next step-local debug round name.""" - debug_dir = debug_dir_path(output_root, step_name) - max_index = 0 - if debug_dir.is_dir(): - for path in debug_dir.iterdir(): - if not path.is_dir(): - continue - match = DEBUG_ROUND_PATTERN.match(path.name) - if match is not None: - max_index = max(max_index, int(match.group(1))) - round_name = f"round_{max_index + 1:03d}" - if label: - round_name = f"{round_name}_{_path_token(label)}" - return round_name - - -def next_debug_round_dir_path( - output_root: Path, - step_name: str, - label: str | None = None, -) -> Path: - """Return the next step-local debug round directory path.""" - return debug_round_dir_path( - output_root, - step_name, - next_debug_round_name(output_root, step_name, label), - ) - def write_step_result( output_root: Path, step_name: str, payload: dict[str, Any], ) -> Path: - """Write a step's final result JSON and return its path.""" path = step_result_path(output_root, step_name) write_json(path, payload) return path @@ -136,7 +74,6 @@ def write_debug_json( filename: str, payload: dict[str, Any], ) -> Path: - """Write a debug JSON file under one step debug round.""" path = debug_round_dir_path(output_root, step_name, round_name) / filename write_json(path, payload) return path @@ -147,7 +84,6 @@ def write_debug_round_json( filename: str, payload: dict[str, Any], ) -> Path: - """Write a debug JSON file under an already selected debug round directory.""" path = debug_round_dir / filename write_json(path, payload) return path @@ -159,7 +95,6 @@ def write_raw_model_output( round_name: str, payload: dict[str, Any], ) -> Path: - """Write one raw structured model output under a step debug round.""" return write_debug_json( output_root, step_name, @@ -175,14 +110,11 @@ def write_next_raw_model_output( payload: dict[str, Any], label: str | None = None, ) -> Path: - """Write raw model output under the next step-local debug round.""" round_name = next_debug_round_name(output_root, step_name, label) return write_raw_model_output(output_root, step_name, round_name, payload) class WorkflowArtifactWriter: - """Write workflow artifacts under a fixed step directory.""" - def __init__(self, output_root: Path, step_name: str) -> None: self._output_root = output_root self._step_name = step_name @@ -208,19 +140,15 @@ def result_path(self) -> Path: return step_result_path(self._output_root, self._step_name) def next_debug_round_name(self, label: str | None = None) -> str: - """Return the next debug round name for this step.""" return next_debug_round_name(self._output_root, self._step_name, label) def next_debug_round_dir(self, label: str | None = None) -> Path: - """Return the next debug round directory for this step.""" return next_debug_round_dir_path(self._output_root, self._step_name, label) def debug_round_dir(self, round_name: str) -> Path: - """Return one debug round directory under this step.""" return debug_round_dir_path(self._output_root, self._step_name, round_name) def write_step_result(self, payload: dict[str, Any]) -> Path: - """Write the step's final result JSON.""" return write_step_result(self._output_root, self._step_name, payload) def write_debug_round_json( @@ -230,7 +158,6 @@ def write_debug_round_json( filename: str, payload: dict[str, Any], ) -> Path: - """Write a JSON artifact inside one named debug round.""" return write_debug_round_json( self.debug_round_dir(round_name), filename=filename, @@ -243,7 +170,6 @@ def write_raw_model_output( round_name: str, payload: dict[str, Any], ) -> Path: - """Write a raw model output into one named debug round.""" return write_raw_model_output( self._output_root, self._step_name, @@ -257,15 +183,9 @@ def write_next_raw_model_output( payload: dict[str, Any], label: str | None = None, ) -> Path: - """Write a raw model output into the next available debug round.""" return write_next_raw_model_output( self._output_root, self._step_name, payload, label=label, ) - - -def _path_token(value: str) -> str: - token = "".join(character if character.isalnum() else "_" for character in value) - return token.strip("_")[:80] or "round" diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py b/embodichain/gen_sim/prompt2scene/workflows/gym_export.py similarity index 74% rename from embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py rename to embodichain/gen_sim/prompt2scene/workflows/gym_export.py index d26a14842..2a68c5eb6 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/gym_export.py +++ b/embodichain/gen_sim/prompt2scene/workflows/gym_export.py @@ -26,9 +26,9 @@ import numpy as np -from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( - STEP_RESULT_FILENAME, +from embodichain.gen_sim.prompt2scene.workflows.paths import ( UNIFIED_SCENE_GEN_STEP, + PipelinePaths, ) __all__ = ["export_gym_config"] @@ -53,11 +53,6 @@ _DEFAULT_MAX_CONVEX_HULL_NUM = 32 -# --------------------------------------------------------------------------- -# helpers -# --------------------------------------------------------------------------- - - def _resolve_path(value: str, output_root: Path) -> Path: path = Path(value).expanduser() if path.is_absolute(): @@ -66,6 +61,10 @@ def _resolve_path(value: str, output_root: Path) -> Path: def _read_json(path: Path) -> dict[str, Any]: + if path.is_dir(): + raise IsADirectoryError(f"Expected JSON file but got directory: {path}") + if not path.is_file(): + raise FileNotFoundError(f"JSON file not found: {path}") with path.open("r", encoding="utf-8") as f: data = json.load(f) if not isinstance(data, dict): @@ -73,6 +72,34 @@ def _read_json(path: Path) -> dict[str, Any]: return data +def _resolve_table_fit_manifest_path( + *, + manifest_path_value: Any, + output_root: Path, + paths: PipelinePaths, +) -> Path: + if not manifest_path_value: + raise FileNotFoundError("table_fit_to_clutter manifest_path is missing or empty") + + resolved = _resolve_path(str(manifest_path_value), output_root) + if resolved.is_file(): + return resolved + + default_manifest = paths.table_fit_manifest + if default_manifest.is_file(): + return default_manifest + + if resolved.is_dir(): + raise IsADirectoryError( + "table_fit_to_clutter manifest_path points to a directory, not a JSON " + f"file: value={manifest_path_value!r} resolved={resolved}" + ) + raise FileNotFoundError( + "table_fit_to_clutter manifest_path does not point to a JSON file: " + f"value={manifest_path_value!r} resolved={resolved}" + ) + + def _matrix_to_euler_xyz_deg(matrix: list[list[float]]) -> list[float]: """Decompose a 3×3 or 4×4 rotation matrix into XYZ Euler angles (degrees).""" m = np.asarray(matrix, dtype=np.float64) @@ -188,11 +215,6 @@ def _rotated_aabb_offsets( ) -# --------------------------------------------------------------------------- -# consolidated object manifest -# --------------------------------------------------------------------------- - - def _build_object_manifest( output_root: Path, step_result: dict[str, Any], @@ -278,11 +300,6 @@ def _build_object_manifest( return consolidated -# --------------------------------------------------------------------------- -# main export -# --------------------------------------------------------------------------- - - def export_gym_config( output_root: Path, *, @@ -300,99 +317,81 @@ def export_gym_config( export_dir = export_dir.expanduser().resolve() export_dir.mkdir(parents=True, exist_ok=True) - # ── data sources ──────────────────────────────────────────────────── - step_result = _read_json( - output_root / UNIFIED_SCENE_GEN_STEP / STEP_RESULT_FILENAME - ) + paths = PipelinePaths(output_root) + + step_result = _read_json(paths.step_result(UNIFIED_SCENE_GEN_STEP)) table_fit = step_result.get("table_fit_to_clutter") or {} + if table_fit.get("status") != "ok": + raise RuntimeError( + "Cannot export gym_config because table_fit_to_clutter did not " + f"succeed: status={table_fit.get('status')!r} " + f"reason={table_fit.get('reason', '')}" + ) + manifest_path_value = table_fit.get("manifest_path") or "" table_fit_manifest = _read_json( - _resolve_path(table_fit.get("manifest_path", ""), output_root) - ) - - aligned_by_id: dict[str, dict[str, Any]] = {} - aligned_manifest_path = ( - output_root - / UNIFIED_SCENE_GEN_STEP - / "glb_gen" - / "simready_to_aligned_manifest.json" - ) - if aligned_manifest_path.is_file(): - for item in _read_json(aligned_manifest_path).get("items", []) or []: - if isinstance(item, dict) and item.get("id"): - aligned_by_id[str(item["id"])] = item - - # ── consolidated per-object manifest ───────────────────────────────── - object_manifest = _build_object_manifest( - output_root, step_result, table_fit_manifest, aligned_by_id + _resolve_table_fit_manifest_path( + manifest_path_value=manifest_path_value, + output_root=output_root, + paths=paths, + ) ) - # ── table ──────────────────────────────────────────────────────────── table_info = step_result.get("table") or {} table_desc = str( table_info.get("complete_table_description") or table_info.get("description", "") ).strip() + object_desc_by_id = { + str(item.get("id", "")): str( + item.get("description") or item.get("name") or "" + ).strip() + for item in step_result.get("objects") or [] + if isinstance(item, dict) and item.get("id") + } mesh_assets_dir = export_dir / "mesh_assets" mesh_assets_dir.mkdir(parents=True, exist_ok=True) - table_simready = _resolve_path( - table_info.get("simready_geometry_path") - or table_info.get("mesh_path", ""), + table_fit_output = _resolve_path( + table_fit_manifest.get("table_output_path", ""), output_root, ) - if not table_simready.is_file(): - raise FileNotFoundError(f"Table simready GLB not found: {table_simready}") + if not table_fit_output.is_file(): + raise FileNotFoundError(f"Table-fit GLB not found: {table_fit_output}") table_dst = mesh_assets_dir / "table" / "table_0.glb" table_dst.parent.mkdir(parents=True, exist_ok=True) - shutil.copy2(table_simready, table_dst) - - table_surface_z = _glb_max_z(table_simready) - - uniform_scale = 1.0 - ts = table_fit_manifest.get("table_xy_scale") - if isinstance(ts, dict): - uniform_scale = float(ts.get("uniform_scale", 1.0)) + shutil.copy2(table_fit_output, table_dst) - # ── objects ────────────────────────────────────────────────────────── rigid_objects: list[dict[str, Any]] = [] - total = len(object_manifest) - for idx, (oid, om) in enumerate(object_manifest.items()): - # Copy simready GLB + fitted_objects = [ + item + for item in table_fit_manifest.get("objects", []) or [] + if isinstance(item, dict) and item.get("id") and item.get("path") + ] + total = len(fitted_objects) + for idx, item in enumerate(fitted_objects): + oid = str(item["id"]) safe_name = oid.replace("interact_", "").strip("_") or "object" obj_dir = mesh_assets_dir / safe_name / oid obj_dir.mkdir(parents=True, exist_ok=True) object_dst = obj_dir / f"{oid}.glb" - shutil.copy2(om["simready_path"], object_dst) - - # body_scale. Image-scene alignment may contain a full simready→aligned - # scale; text-scene layout only has the per-object metric scale. - sf = om["scale_factor"] - scale_glb = om.get("transform_scale") or [sf, sf, sf] - body_scale = _glb_scale_to_sim(scale_glb) - - # init_rot - init_rot: list[float] = [0.0, 0.0, 0.0] - if om["rotation_matrix"] is not None: - init_rot = _matrix_to_euler_xyz_deg( - _glb_rotation_to_sim(om["rotation_matrix"]) - ) - - # init_pos = world_bc - rotated_aabb_offset - ro = _rotated_aabb_offsets( - om["simready_path"], om["rotation_matrix"], scale_glb - ) - wbc = om["world_aabb_bottom_center"] - if wbc is not None: - init_pos = [wbc[0] - ro[0], wbc[1] - ro[1], wbc[2] - ro[2]] - else: - init_pos = [-ro[0], -ro[1], table_surface_z - ro[2]] + object_fit_path = _resolve_path(str(item["path"]), output_root) + if not object_fit_path.is_file(): + raise FileNotFoundError(f"Table-fit object GLB not found: {object_fit_path}") + shutil.copy2(object_fit_path, object_dst) + + # Table-fit GLBs already have the relative layout baked into vertices. + # Preview/export should not rebuild placement from simready transforms. + init_pos = [0.0, 0.0, 0.0] + init_rot = [0.0, 0.0, 0.0] + body_scale = [1.0, 1.0, 1.0] + description = object_desc_by_id.get(oid, oid) rigid_objects.append( { "uid": oid, - "description": om["description"], + "description": description, "shape": { "shape_type": "Mesh", "fpath": str(object_dst.relative_to(export_dir)), @@ -406,14 +405,11 @@ def export_gym_config( "max_convex_hull_num": _DEFAULT_MAX_CONVEX_HULL_NUM, } ) - wbc = om["world_aabb_bottom_center"] - wbc_flag = "wbc" if wbc is not None else "fallback" print( - f" [{idx+1}/{total}] [{oid}] {om['description']}" - f" pos={init_pos} rot={init_rot} scale={body_scale} src={wbc_flag}" + f" [{idx+1}/{total}] [{oid}] {description}" + f" pos={init_pos} rot={init_rot} scale={body_scale} src=table_fit_glb" ) - # ── write gym config ───────────────────────────────────────────────── config = { "id": f"Prompt2Scene-{int(time.time() * 1000)}-v0", "max_episodes": 10, @@ -432,7 +428,7 @@ def export_gym_config( "compute_uv": False, }, "attrs": dict(_DEFAULT_TABLE_ATTRS), - "body_scale": [uniform_scale, uniform_scale, 1.0], + "body_scale": [1.0, 1.0, 1.0], "body_type": "kinematic", "init_pos": [0.0, 0.0, 0.0], "init_rot": [0.0, 0.0, 0.0], diff --git a/embodichain/gen_sim/prompt2scene/workflows/image_relations/nodes.py b/embodichain/gen_sim/prompt2scene/workflows/image_relations/nodes.py index ab8b69522..fbaad0e50 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/image_relations/nodes.py +++ b/embodichain/gen_sim/prompt2scene/workflows/image_relations/nodes.py @@ -19,9 +19,9 @@ from pathlib import Path from typing import Any -from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client import ( - decode_rle_mask, - draw_numbered_masks, +from embodichain.gen_sim.prompt2scene.agent_tools.tools.image_segment_filter import ( + filter_group_segments_with_vlm, + filter_segments_with_vlm, ) from embodichain.gen_sim.prompt2scene.workflows.image_relations.schema import ( ImageAssetSegment, @@ -29,8 +29,7 @@ ImageRelationSpec, ) from embodichain.gen_sim.prompt2scene.workflows.request import InputKind -from embodichain.gen_sim.prompt2scene.workflows.image_relations.schema import ( - FILTER_EXTRA_INSTANCES_JSON_SCHEMA, +from embodichain.gen_sim.prompt2scene.prompts.schemas import ( SPATIAL_LAYOUT_JSON_SCHEMA, ) from embodichain.gen_sim.prompt2scene.utils import ( @@ -48,23 +47,24 @@ asset_bbox_label, draw_labeled_bboxes, expand_asset_ids, - filter_group_segments_with_vlm, - filter_segments_with_vlm, merge_non_overlapping_segments, prompt_text, path_token, require_image_path, + segment_area, segment_prompt, segments_from_response, + select_largest_table_segment, + table_segmentation_prompts, + write_table_candidate_debug_image, ) -from embodichain.gen_sim.prompt2scene.workflows.image_relations.prompts import ( - build_filter_extra_instances_messages, +from embodichain.gen_sim.prompt2scene.prompts.builders import ( build_spatial_layout_messages, ) from embodichain.gen_sim.prompt2scene.workflows.image_relations.state import ( ImageRelationsState, ) -from embodichain.gen_sim.prompt2scene.workflows.llm_output import ( +from embodichain.gen_sim.prompt2scene.llms.llm_output import ( call_structured_json_model_step, is_model_output_error, ) @@ -140,7 +140,20 @@ def call_vlm_filter_initial_segments_node( llm: Any, ) -> dict[str, object]: """Ask VLM to remove wrong masks from initial name-based SAM3 output.""" - return filter_segments_with_vlm(state=state, llm=llm, stage="initial") + image_path = require_image_path(state) + art = WorkflowArtifactWriter(state["output_root"], IMAGE_SEGMENTS_STEP) + return filter_segments_with_vlm( + llm=llm, + image_path=image_path, + step_name=IMAGE_SEGMENTS_STEP, + segment_groups=state["segment_groups"], + attempt_count=state["attempt_count"], + errors=state["errors"], + stage="initial", + next_debug_round_name=art.next_debug_round_name, + debug_round_dir=art.debug_round_dir, + write_debug_json=art.write_debug_round_json, + ) def retry_missing_by_candidates_node( state: ImageRelationsState, *, @@ -167,13 +180,21 @@ def retry_missing_by_candidates_node( response=response, source_prompt=prompt, ) + stage_label = f"fallback_{path_token(prompt)}" + round_name_inner = artifact_writer.next_debug_round_name( + label=f"{stage_label}_{group['name']}" + ) + round_dir_inner = artifact_writer.debug_round_dir(round_name_inner) new_segments = filter_group_segments_with_vlm( llm=llm, image_path=image_path, - artifact_writer=artifact_writer, + step_name=IMAGE_SEGMENTS_STEP, group=group, segments=new_segments, - stage=f"fallback_{path_token(prompt)}", + stage=stage_label, + debug_round_name=round_name_inner, + debug_round_dir=round_dir_inner, + write_debug_json=artifact_writer.write_debug_round_json, ) segments = merge_non_overlapping_segments( existing=segments, @@ -324,7 +345,7 @@ def segment_table_node( } segments: list[dict[str, Any]] = [] - for prompt in _table_segmentation_prompts(group): + for prompt in table_segmentation_prompts(group): if len(segments) >= 1: break response = segment_prompt(image_path=image_path, prompt=prompt) @@ -334,14 +355,14 @@ def segment_table_node( response=response, source_prompt=prompt, ) - _write_table_candidate_debug_image( + write_table_candidate_debug_image( image_path=image_path, artifact_writer=artifact_writer, group=group, segments=new_segments, stage=f"table_{path_token(prompt)}", ) - selected_segment = _select_largest_table_segment(new_segments) + selected_segment = select_largest_table_segment(new_segments) if selected_segment is not None: segments = [selected_segment] @@ -387,66 +408,9 @@ def segment_table_node( return {"image_relations": updated_image_relations} -def _table_segmentation_prompts(group: dict[str, Any]) -> list[str]: - """Return table/support segmentation prompts in object-style fallback order.""" - prompts = [prompt_text(group["name"])] - for candidate_name in group["class_candidate"][1:]: - prompts.append(prompt_text(candidate_name)) - description_prompt = str(group.get("description") or "").strip() - if description_prompt: - prompts.append(description_prompt) - unique_prompts: list[str] = [] - for prompt in prompts: - if prompt and prompt not in unique_prompts: - unique_prompts.append(prompt) - return unique_prompts - - -def _write_table_candidate_debug_image( - *, - image_path: Path, - artifact_writer: WorkflowArtifactWriter, - group: dict[str, Any], - segments: list[dict[str, Any]], - stage: str, -) -> None: - """Write table/support candidate mask debug image without VLM filtering.""" - if not segments: - return - round_name = artifact_writer.next_debug_round_name(label=f"{stage}_{group['name']}") - round_dir = artifact_writer.debug_round_dir(round_name) - debug_image_path = draw_numbered_masks( - image_path=image_path, - segments=segments, - output_path=round_dir / "mask.png", - ) - group["debug_images"] = append_unique( - group["debug_images"], - str(debug_image_path), - ) - - -def _select_largest_table_segment( - segments: list[dict[str, Any]], -) -> dict[str, Any] | None: - """Select the largest SAM3 table/support candidate without VLM filtering.""" - if not segments: - return None - return max(segments, key=_segment_area) - - -def _segment_area(segment: dict[str, Any]) -> float: - mask_rle = segment.get("mask_rle") - if mask_rle is not None: - try: - mask = decode_rle_mask(mask_rle).convert("L") - histogram = mask.histogram() - return float(sum(count for value, count in enumerate(histogram) if value)) - except Exception: - pass - x1, y1, x2, y2 = segment["bbox_xyxy"] - return max(0.0, float(x2) - float(x1)) * max(0.0, float(y2) - float(y1)) + artifact_writer.write_step_result(updated_image_relations.to_segmentation_manifest()) + return {"image_relations": updated_image_relations} def call_vlm_spatial_layout_node( @@ -483,11 +447,11 @@ def call_vlm_spatial_layout_node( schema=SPATIAL_LAYOUT_JSON_SCHEMA, messages=messages, context="Image spatial layout", - step_name=IMAGE_SPATIAL_RELATIONS_STEP, - output_root=None, + + attempt_count=attempt_count, - raw_output_label="spatial_layout", - artifact_writer=artifact_writer, + + ) updated_image_relations = apply_spatial_layout_output( image_relations=image_relations, @@ -496,7 +460,9 @@ def call_vlm_spatial_layout_node( artifact_writer.write_step_result(updated_image_relations.to_spatial_manifest()) except Exception as exc: if is_model_output_error(exc) or isinstance(exc, ValueError): - error = format_attempt_error("Image relations spatial layout", attempt_count, exc) + error = format_attempt_error( + "Image relations spatial layout", attempt_count, exc + ) log.log_warning(error) return { "attempt_count": attempt_count, diff --git a/embodichain/gen_sim/prompt2scene/workflows/image_relations/prompts.py b/embodichain/gen_sim/prompt2scene/workflows/image_relations/prompts.py deleted file mode 100644 index f974f442e..000000000 --- a/embodichain/gen_sim/prompt2scene/workflows/image_relations/prompts.py +++ /dev/null @@ -1,113 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - -from __future__ import annotations - -from pathlib import Path -from typing import Any - -from embodichain.gen_sim.prompt2scene.prompts import render_prompt -from embodichain.gen_sim.prompt2scene.utils.io import image_to_data_url - -__all__ = [ - "build_filter_extra_instances_messages", - "build_spatial_layout_messages", -] - -IMAGE_RELATIONS_PROMPT_NAME = "image_relations.yaml" - - -def build_filter_extra_instances_messages( - *, - debug_image_path: Path, - name: str, - description: str, - expected_count: int, - class_candidate: list[str], -) -> list[dict[str, Any]]: - """Build LangChain-compatible messages for VLM extra-mask filtering.""" - return [ - { - "role": "system", - "content": render_prompt( - IMAGE_RELATIONS_PROMPT_NAME, - prompt_key="filter_extra_instances_system", - ), - }, - { - "role": "user", - "content": [ - { - "type": "text", - "text": render_prompt( - IMAGE_RELATIONS_PROMPT_NAME, - { - "name": name.replace("_", " "), - "description": description, - "expected_count": str(expected_count), - "class_candidate": ", ".join( - candidate.replace("_", " ") - for candidate in class_candidate - ), - }, - prompt_key="filter_extra_instances_user", - ), - }, - { - "type": "image_url", - "image_url": {"url": image_to_data_url(debug_image_path)}, - }, - ], - }, - ] - - -def build_spatial_layout_messages( - *, - bbox_name_image_path: Path, - asset_ids: list[str], -) -> list[dict[str, Any]]: - """Build messages for VLM spatial ordering and object-state extraction.""" - return [ - { - "role": "system", - "content": render_prompt( - IMAGE_RELATIONS_PROMPT_NAME, - prompt_key="spatial_layout_system", - ), - }, - { - "role": "user", - "content": [ - { - "type": "text", - "text": render_prompt( - IMAGE_RELATIONS_PROMPT_NAME, - { - "asset_ids": "\n".join( - f"- {asset_id}" for asset_id in asset_ids - ), - }, - prompt_key="spatial_layout_user", - ), - }, - { - "type": "image_url", - "image_url": {"url": image_to_data_url(bbox_name_image_path)}, - }, - ], - }, - ] diff --git a/embodichain/gen_sim/prompt2scene/workflows/image_relations/schema.py b/embodichain/gen_sim/prompt2scene/workflows/image_relations/schema.py index 500f7c702..91dc4583e 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/image_relations/schema.py +++ b/embodichain/gen_sim/prompt2scene/workflows/image_relations/schema.py @@ -19,95 +19,14 @@ from dataclasses import dataclass, field from typing import Any -from embodichain.gen_sim.prompt2scene.workflows.spatial import GRID_VALUE_LIST - __all__ = [ - "FILTER_EXTRA_INSTANCES_JSON_SCHEMA", "ImageAnchor", "ImageAssetLayout", "ImageAssetSegment", "ImageRelationGroup", "ImageRelationSpec", - "SPATIAL_LAYOUT_JSON_SCHEMA", ] -FILTER_EXTRA_INSTANCES_JSON_SCHEMA: dict[str, Any] = { - "title": "FilterExtraImageInstancesOutput", - "type": "object", - "additionalProperties": False, - "properties": { - "extra_instance_numbers": { - "type": "array", - "description": "1-based mask numbers that should be removed.", - "items": {"type": "integer", "minimum": 1}, - }, - "reason": { - "type": "string", - "description": "Brief reason for the removal decision.", - }, - }, - "required": ["extra_instance_numbers", "reason"], -} - -SPATIAL_LAYOUT_JSON_SCHEMA: dict[str, Any] = { - "title": "ImageSpatialLayoutOutput", - "type": "object", - "additionalProperties": False, - "properties": { - "anchor": { - "type": "object", - "additionalProperties": False, - "properties": { - "asset_id": {"type": "string", "minLength": 1}, - "grid": { - "type": "string", - "enum": GRID_VALUE_LIST, - }, - "reason": {"type": "string"}, - }, - "required": ["asset_id", "grid", "reason"], - }, - "x_order": { - "type": "array", - "description": "Asset-id groups ordered from left to right.", - "items": { - "type": "array", - "items": {"type": "string", "minLength": 1}, - "minItems": 1, - }, - "minItems": 1, - }, - "y_order": { - "type": "array", - "description": "Asset-id groups ordered from front to back.", - "items": { - "type": "array", - "items": {"type": "string", "minLength": 1}, - "minItems": 1, - }, - "minItems": 1, - }, - "asset_states": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": True, - "properties": { - "asset_id": {"type": "string", "minLength": 1}, - "is_arbitrary_layout": {"type": "boolean"}, - "reason": {"type": "string", "minLength": 1}, - }, - "required": [ - "asset_id", - "is_arbitrary_layout", - "reason", - ], - }, - }, - }, - "required": ["anchor", "x_order", "y_order", "asset_states"], -} - @dataclass(frozen=True) class ImageAssetSegment: diff --git a/embodichain/gen_sim/prompt2scene/workflows/image_relations/utils.py b/embodichain/gen_sim/prompt2scene/workflows/image_relations/utils.py index 27e3b1b39..5a7070832 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/image_relations/utils.py +++ b/embodichain/gen_sim/prompt2scene/workflows/image_relations/utils.py @@ -25,21 +25,23 @@ ImageSegmentationServerRequest, ImageSegmentationServerResponse, bbox_iou, + decode_rle_mask, draw_labeled_bboxes, draw_numbered_masks, is_usable_segmentation_candidate, sort_segments_by_bbox, ) +from embodichain.gen_sim.prompt2scene.prompts.schemas import ( + SPATIAL_LAYOUT_JSON_SCHEMA, +) from embodichain.gen_sim.prompt2scene.workflows.image_relations.schema import ( - FILTER_EXTRA_INSTANCES_JSON_SCHEMA, ImageAnchor, ImageAssetLayout, ImageAssetSegment, ImageRelationGroup, ImageRelationSpec, - SPATIAL_LAYOUT_JSON_SCHEMA, ) -from embodichain.gen_sim.prompt2scene.workflows.spatial import ( +from embodichain.gen_sim.prompt2scene.agent_tools.tools.spatial_relations import ( GRID_VALUES, validate_exact_asset_id_coverage, ) @@ -47,19 +49,13 @@ from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( IMAGE_SEGMENTS_STEP, IMAGE_SPATIAL_RELATIONS_STEP, - RAW_MODEL_OUTPUT_FILENAME, WorkflowArtifactWriter, ) -from embodichain.gen_sim.prompt2scene.workflows.image_relations.prompts import ( - build_filter_extra_instances_messages, +from embodichain.gen_sim.prompt2scene.prompts.builders import ( build_spatial_layout_messages, ) -from embodichain.gen_sim.prompt2scene.workflows.llm_output import ( +from embodichain.gen_sim.prompt2scene.llms.llm_output import ( call_structured_json_model_step, - is_model_output_error, -) -from embodichain.gen_sim.prompt2scene.workflows.stage_errors import ( - format_attempt_error, ) __all__ = [ @@ -68,21 +64,22 @@ "append_unique", "apply_spatial_layout_output", "asset_bbox_label", + "draw_labeled_bboxes", "expand_asset_ids", - "filter_group_segments_with_vlm", - "filter_segments_with_vlm", "merge_non_overlapping_segments", - "draw_labeled_bboxes", "parse_anchor", "parse_asset_states", "parse_order_groups", "path_token", "prompt_text", - "remove_extra_numbered_segments", "require_image_path", + "segment_area", "segment_prompt", "segments_from_response", + "select_largest_table_segment", "sort_segments_by_bbox", + "table_segmentation_prompts", + "write_table_candidate_debug_image", ] MAX_SEGMENT_RETRIES = 1 @@ -297,121 +294,68 @@ def parse_asset_states( return state_by_asset_id -def filter_group_segments_with_vlm( +def table_segmentation_prompts(group: dict[str, Any]) -> list[str]: + """Return table/support segmentation prompts in object-style fallback order.""" + prompts = [prompt_text(group["name"])] + for candidate_name in group["class_candidate"][1:]: + prompts.append(prompt_text(candidate_name)) + description_prompt = str(group.get("description") or "").strip() + if description_prompt: + prompts.append(description_prompt) + + unique_prompts: list[str] = [] + for prompt in prompts: + if prompt and prompt not in unique_prompts: + unique_prompts.append(prompt) + return unique_prompts + + +def write_table_candidate_debug_image( *, - llm: Any, image_path: Path, artifact_writer: WorkflowArtifactWriter, group: dict[str, Any], segments: list[dict[str, Any]], stage: str, -) -> list[dict[str, Any]]: - """Ask VLM to remove wrong or duplicate instances from one SAM3 result.""" - segments = sort_segments_by_bbox(segments) +) -> None: + """Write table/support candidate mask debug image without VLM filtering.""" if not segments: - return segments - - round_name = artifact_writer.next_debug_round_name(label=f"{stage}_{group['name']}") + return + round_name = artifact_writer.next_debug_round_name( + label=f"{stage}_{group['name']}" + ) round_dir = artifact_writer.debug_round_dir(round_name) debug_image_path = draw_numbered_masks( image_path=image_path, segments=segments, output_path=round_dir / "mask.png", ) - group["debug_images"] = append_unique( - group["debug_images"], - str(debug_image_path), - ) - log_api_request_start( - step=IMAGE_SEGMENTS_STEP, - request=f"vlm_filter_{stage}", - debug_image=str(debug_image_path), - ) - messages = build_filter_extra_instances_messages( - debug_image_path=debug_image_path, - name=group["name"], - description=group["description"], - expected_count=group["expected_count"], - class_candidate=group["class_candidate"], - ) - raw_model_output = call_structured_json_model_step( - llm=llm, - schema=FILTER_EXTRA_INSTANCES_JSON_SCHEMA, - messages=messages, - context=f"Image relation {stage} segmentation filtering", - step_name=IMAGE_SEGMENTS_STEP, - output_root=None, - attempt_count=0, - raw_output_writer=lambda payload: artifact_writer.write_debug_round_json( - round_name=round_name, - filename=RAW_MODEL_OUTPUT_FILENAME, - payload=payload, - ), - ) - return remove_extra_numbered_segments( - segments=segments, - raw_model_output=raw_model_output, - ) + debug_images = list(group.get("debug_images") or []) + if str(debug_image_path) not in debug_images: + debug_images.append(str(debug_image_path)) + group["debug_images"] = debug_images -def filter_segments_with_vlm( - *, - state: dict[str, Any], - llm: Any, - stage: str, -) -> dict[str, object]: - """Filter all segment groups with VLM and return an updated state patch.""" - segment_groups = [] - attempt_count = state["attempt_count"] + 1 - image_path = require_image_path(state) - artifact_writer = WorkflowArtifactWriter(state["output_root"], IMAGE_SEGMENTS_STEP) - - try: - for group in state["segment_groups"]: - group = dict(group) - group["segments"] = filter_group_segments_with_vlm( - llm=llm, - image_path=image_path, - artifact_writer=artifact_writer, - group=group, - segments=group["segments"], - stage=stage, - ) - segment_groups.append(group) - except Exception as exc: - if is_model_output_error(exc) or isinstance(exc, ValueError): - error = format_attempt_error("Image relations VLM filter", attempt_count, exc) - log.log_warning(error) - return { - "attempt_count": attempt_count, - "last_error": error, - "errors": state["errors"] + [error], - } - raise - - return { - "attempt_count": attempt_count, - "segment_groups": segment_groups, - "last_error": None, - } +def select_largest_table_segment( + segments: list[dict[str, Any]], +) -> dict[str, Any] | None: + """Select the largest SAM3 table/support candidate without VLM filtering.""" + if not segments: + return None + return max(segments, key=segment_area) -def remove_extra_numbered_segments( - *, - segments: list[dict[str, Any]], - raw_model_output: dict[str, Any], -) -> list[dict[str, Any]]: - """Remove numbered masks flagged as extra by the VLM.""" - extra_numbers = raw_model_output.get("extra_instance_numbers") - if not isinstance(extra_numbers, list): - raise ValueError("extra_instance_numbers must be a list.") - extra_indices = {int(number) - 1 for number in extra_numbers} - if any(index < 0 or index >= len(segments) for index in extra_indices): - raise ValueError("VLM returned an out-of-range extra mask number.") - kept = [ - segment for index, segment in enumerate(segments) if index not in extra_indices - ] - return kept +def segment_area(segment: dict[str, Any]) -> float: + mask_rle = segment.get("mask_rle") + if mask_rle is not None: + try: + mask = decode_rle_mask(mask_rle).convert("L") + histogram = mask.histogram() + return float(sum(count for value, count in enumerate(histogram) if value)) + except Exception: + pass + x1, y1, x2, y2 = segment["bbox_xyxy"] + return max(0.0, float(x2) - float(x1)) * max(0.0, float(y2) - float(y1)) def merge_non_overlapping_segments( @@ -428,7 +372,8 @@ def merge_non_overlapping_segments( if len(merged) >= limit: break if all( - bbox_iou(segment["bbox_xyxy"], other["bbox_xyxy"]) < OVERLAP_IOU_THRESHOLD + bbox_iou(segment["bbox_xyxy"], other["bbox_xyxy"]) + < OVERLAP_IOU_THRESHOLD for other in merged ): merged.append(segment) diff --git a/embodichain/gen_sim/prompt2scene/workflows/paths.py b/embodichain/gen_sim/prompt2scene/workflows/paths.py new file mode 100644 index 000000000..21243fa62 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/paths.py @@ -0,0 +1,219 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +__all__ = [ + "DEBUG_DIRNAME", + "IMAGE_SEGMENTS_STEP", + "IMAGE_SPATIAL_RELATIONS_STEP", + "RAW_MODEL_OUTPUT_FILENAME", + "SCENE_INTAKE_STEP", + "STEP_RESULT_FILENAME", + "TEXT_RELATIONS_STEP", + "UNIFIED_SCENE_GEN_STEP", + "UNIFIED_SCENE_STEP", + "PipelinePaths", + "debug_dir_path", + "debug_round_dir_path", + "next_debug_round_dir_path", + "next_debug_round_name", + "resolve_generated_path", + "step_dir_path", + "step_result_path", +] + +STEP_RESULT_FILENAME = "result.json" +DEBUG_DIRNAME = "debug" +RAW_MODEL_OUTPUT_FILENAME = "raw_model_output.json" + +SCENE_INTAKE_STEP = "scene_intake" +IMAGE_SEGMENTS_STEP = "image_segments" +IMAGE_SPATIAL_RELATIONS_STEP = "image_spatial_relations" +TEXT_RELATIONS_STEP = "text_relations" +UNIFIED_SCENE_STEP = "unified_scene" +UNIFIED_SCENE_GEN_STEP = "unified_scene_gen" + +_DEBUG_ROUND_PATTERN = re.compile(r"^round_(\d+)(?:_|$)") + + +def step_dir_path(output_root: Path, step_name: str) -> Path: + return output_root / step_name + + +def step_result_path(output_root: Path, step_name: str) -> Path: + return step_dir_path(output_root, step_name) / STEP_RESULT_FILENAME + + +def debug_dir_path(output_root: Path, step_name: str) -> Path: + return step_dir_path(output_root, step_name) / DEBUG_DIRNAME + + +def debug_round_dir_path( + output_root: Path, + step_name: str, + round_name: str, +) -> Path: + return debug_dir_path(output_root, step_name) / round_name + + +def next_debug_round_name( + output_root: Path, + step_name: str, + label: str | None = None, +) -> str: + debug_dir = debug_dir_path(output_root, step_name) + max_index = 0 + if debug_dir.is_dir(): + for entry in debug_dir.iterdir(): + if not entry.is_dir(): + continue + match = _DEBUG_ROUND_PATTERN.match(entry.name) + if match is not None: + max_index = max(max_index, int(match.group(1))) + name = f"round_{max_index + 1:03d}" + if label: + name = f"{name}_{_path_token(label)}" + return name + + +def next_debug_round_dir_path( + output_root: Path, + step_name: str, + label: str | None = None, +) -> Path: + return debug_round_dir_path( + output_root, + step_name, + next_debug_round_name(output_root, step_name, label), + ) + + +def resolve_generated_path(value: Any, output_root: Path) -> Path: + if not value: + return Path() + path = Path(str(value)).expanduser() + if path.is_absolute(): + return path.resolve() + return (output_root.expanduser().resolve() / path).resolve() + + +def _path_token(value: str) -> str: + token = "".join(c if c.isalnum() else "_" for c in value) + return token.strip("_")[:80] or "round" + + +@dataclass(frozen=True) +class PipelinePaths: + output_root: Path + + def __post_init__(self) -> None: + object.__setattr__(self, "output_root", self.output_root.expanduser().resolve()) + + @property + def scene_intake_dir(self) -> Path: + return self.output_root / SCENE_INTAKE_STEP + + @property + def image_segments_dir(self) -> Path: + return self.output_root / IMAGE_SEGMENTS_STEP + + @property + def image_spatial_relations_dir(self) -> Path: + return self.output_root / IMAGE_SPATIAL_RELATIONS_STEP + + @property + def text_relations_dir(self) -> Path: + return self.output_root / TEXT_RELATIONS_STEP + + @property + def unified_scene_dir(self) -> Path: + return self.output_root / UNIFIED_SCENE_STEP + + @property + def unified_scene_gen_dir(self) -> Path: + return self.output_root / UNIFIED_SCENE_GEN_STEP + + def step_result(self, step: str) -> Path: + return step_result_path(self.output_root, step) + + @property + def scene_intake_result(self) -> Path: + return self.step_result(SCENE_INTAKE_STEP) + + @property + def image_segments_result(self) -> Path: + return self.step_result(IMAGE_SEGMENTS_STEP) + + @property + def unified_scene_result(self) -> Path: + return self.step_result(UNIFIED_SCENE_STEP) + + def resolve_scene_result(self, explicit_path: Path | None) -> Path: + if explicit_path is not None: + return explicit_path.expanduser().resolve() + result = self.unified_scene_result + if result.is_file(): + return result + legacy = self.unified_scene_dir / "results.json" + return legacy if legacy.is_file() else result + + @property + def gen_image_dir(self) -> Path: + return self.unified_scene_gen_dir / "image_gen" + + @property + def gen_glb_dir(self) -> Path: + return self.unified_scene_gen_dir / "glb_gen" + + @property + def gen_debug_dir(self) -> Path: + return self.unified_scene_gen_dir / "debug" + + @property + def text_clutter_dir(self) -> Path: + return self.gen_glb_dir / "text_clutter_settled" + + @property + def table_fit_dir(self) -> Path: + return self.gen_glb_dir / "table_fit_to_clutter" + + @property + def simready_to_aligned_manifest(self) -> Path: + return self.gen_glb_dir / "simready_to_aligned_manifest.json" + + @property + def table_fit_manifest(self) -> Path: + return self.table_fit_dir / "table_fit_to_clutter_manifest.json" + + @property + def gym_export_dir(self) -> Path: + return self.output_root / "gym_export" + + @property + def gym_config(self) -> Path: + return self.gym_export_dir / "gym_config.json" + + def prepare_generation_dirs(self) -> tuple[Path, Path, Path]: + dirs = (self.gen_image_dir, self.gen_glb_dir, self.gen_debug_dir) + for d in dirs: + d.mkdir(parents=True, exist_ok=True) + return dirs diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/nodes.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/nodes.py index 8c7baf55c..5b579d78d 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/nodes.py +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/nodes.py @@ -18,8 +18,10 @@ from typing import Any -from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( +from embodichain.gen_sim.prompt2scene.prompts.schemas import ( SCENE_INTAKE_JSON_SCHEMA, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( SceneIntakeSpec, ) from embodichain.gen_sim.prompt2scene.utils import ( @@ -30,14 +32,14 @@ SCENE_INTAKE_STEP, WorkflowArtifactWriter, ) -from embodichain.gen_sim.prompt2scene.workflows.llm_output import ( +from embodichain.gen_sim.prompt2scene.llms.llm_output import ( StructuredModelCallError, call_structured_json_model_step, ) from embodichain.gen_sim.prompt2scene.workflows.stage_errors import ( format_attempt_error, ) -from embodichain.gen_sim.prompt2scene.workflows.scene_intake.prompts import ( +from embodichain.gen_sim.prompt2scene.prompts.builders import ( build_scene_intake_messages, build_scene_intake_verifier_messages, ) @@ -85,11 +87,11 @@ def call_vlm_scene_intake_node( schema=SCENE_INTAKE_JSON_SCHEMA, messages=state["messages"], context="Scene intake", - step_name=SCENE_INTAKE_STEP, - output_root=None, + + attempt_count=attempt_count, - raw_output_label="extract", - artifact_writer=artifact_writer, + + ) except StructuredModelCallError as exc: error = format_attempt_error("Scene intake", attempt_count, exc) @@ -161,11 +163,11 @@ def call_vlm_verify_scene_intake_node( schema=SCENE_INTAKE_JSON_SCHEMA, messages=messages, context="Scene intake verifier", - step_name=SCENE_INTAKE_STEP, - output_root=None, + + attempt_count=attempt_count, - raw_output_label="verify", - artifact_writer=artifact_writer, + + ) except StructuredModelCallError as exc: error = format_attempt_error("Scene intake verifier", attempt_count, exc) diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/prompts.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/prompts.py deleted file mode 100644 index 421ec979b..000000000 --- a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/prompts.py +++ /dev/null @@ -1,202 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - -from __future__ import annotations - -import json -from typing import Any - -from embodichain.gen_sim.prompt2scene.prompts import render_prompt -from embodichain.gen_sim.prompt2scene.workflows.request import ( - InputKind, - Prompt2SceneInput, -) -from embodichain.gen_sim.prompt2scene.utils.io import image_to_data_url -from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( - SceneIntakeSpec, -) - -__all__ = [ - "build_scene_intake_messages", - "build_scene_intake_verifier_messages", -] - -SCENE_INTAKE_PROMPT_NAME = "scene_intake.yaml" - - -def build_scene_intake_messages(request: Prompt2SceneInput) -> list[dict[str, Any]]: - """Build LangChain-compatible messages for scene intake.""" - if request.input_kind == InputKind.TEXT: - return _build_text_messages(request) - return _build_image_messages(request) - - -def build_scene_intake_verifier_messages( - *, - request: Prompt2SceneInput, - scene_intake: SceneIntakeSpec, -) -> list[dict[str, Any]]: - """Build messages for scene-intake group and count verification.""" - table_draft: dict[str, object] = { - "name": scene_intake.table.name, - "description": scene_intake.table.description, - "complete_table_description": ( - scene_intake.table.complete_table_description - ), - "is_complete_visible_table": ( - scene_intake.table.is_complete_visible_table - ), - "class_candidate": list(scene_intake.table.class_candidate), - } - if scene_intake.table.object_coverage_percent is not None: - table_draft["object_coverage_percent"] = ( - scene_intake.table.object_coverage_percent - ) - scene_intake_json = json.dumps( - { - "table": table_draft, - "assets": [ - { - "name": asset.name, - "description": asset.description, - "class_candidate": list(asset.class_candidate), - "count": asset.count, - } - for asset in scene_intake.assets - ], - }, - ensure_ascii=False, - indent=2, - ) - if request.input_kind == InputKind.TEXT: - return _build_text_verifier_messages( - request=request, - scene_intake_json=scene_intake_json, - ) - return _build_image_verifier_messages( - request=request, - scene_intake_json=scene_intake_json, - ) - - -def _build_text_messages(request: Prompt2SceneInput) -> list[dict[str, Any]]: - return [ - { - "role": "system", - "content": render_prompt(SCENE_INTAKE_PROMPT_NAME, prompt_key="text_system"), - }, - { - "role": "user", - "content": render_prompt( - SCENE_INTAKE_PROMPT_NAME, - {"text": request.text or ""}, - prompt_key="text_user", - ), - }, - ] - - -def _build_image_messages(request: Prompt2SceneInput) -> list[dict[str, Any]]: - image_path = request.image_path - if image_path is None: - raise ValueError("Image input requires image_path.") - - return [ - { - "role": "system", - "content": render_prompt(SCENE_INTAKE_PROMPT_NAME, prompt_key="image_system"), - }, - { - "role": "user", - "content": [ - { - "type": "text", - "text": render_prompt( - SCENE_INTAKE_PROMPT_NAME, - prompt_key="image_user", - ), - }, - { - "type": "image_url", - "image_url": {"url": image_to_data_url(image_path)}, - }, - ], - }, - ] - - -def _build_text_verifier_messages( - *, - request: Prompt2SceneInput, - scene_intake_json: str, -) -> list[dict[str, Any]]: - return [ - { - "role": "system", - "content": render_prompt( - SCENE_INTAKE_PROMPT_NAME, - prompt_key="verifier_system", - ), - }, - { - "role": "user", - "content": render_prompt( - SCENE_INTAKE_PROMPT_NAME, - { - "text": request.text or "", - "scene_intake_json": scene_intake_json, - }, - prompt_key="verifier_text_user", - ), - }, - ] - - -def _build_image_verifier_messages( - *, - request: Prompt2SceneInput, - scene_intake_json: str, -) -> list[dict[str, Any]]: - image_path = request.image_path - if image_path is None: - raise ValueError("Image input requires image_path.") - - return [ - { - "role": "system", - "content": render_prompt( - SCENE_INTAKE_PROMPT_NAME, - prompt_key="verifier_system", - ), - }, - { - "role": "user", - "content": [ - { - "type": "text", - "text": render_prompt( - SCENE_INTAKE_PROMPT_NAME, - {"scene_intake_json": scene_intake_json}, - prompt_key="verifier_image_user", - ), - }, - { - "type": "image_url", - "image_url": {"url": image_to_data_url(image_path)}, - }, - ], - }, - ] diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/schema.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/schema.py index 31b55e6df..e0bcf3a59 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/schema.py +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/schema.py @@ -25,149 +25,12 @@ ) __all__ = [ - "SCENE_INTAKE_JSON_SCHEMA", "SceneIntakeAsset", "SceneIntakeInputRecord", "SceneIntakeSpec", "SceneIntakeTable", ] -SCENE_INTAKE_JSON_SCHEMA: dict[str, Any] = { - "title": "SceneIntakeModelOutput", - "description": ( - "Objects and table information extracted from a text or image input." - ), - "type": "object", - "additionalProperties": False, - "properties": { - "table": { - "type": "object", - "additionalProperties": False, - "properties": { - "name": { - "type": "string", - "description": ( - "Canonical English class name for the visible table " - "or tabletop target, such as table, desk, dining_table, " - "coffee_table, workbench, or tabletop." - ), - }, - "description": { - "type": "string", - "minLength": 20, - "maxLength": 180, - "description": ( - "One concise standalone appearance description of the " - "visible table or tabletop region." - ), - }, - "complete_table_description": { - "type": "string", - "minLength": 20, - "maxLength": 220, - "description": ( - "One concise standalone description of a complete table " - "asset for text-to-3D generation, matching the visible " - "tabletop color, material, and texture." - ), - }, - "is_complete_visible_table": { - "type": "boolean", - "description": ( - "For image input, whether a mostly complete table is " - "visible and suitable as the final table geometry source. " - "For text input, this should be false." - ), - }, - "class_candidate": { - "type": "array", - "minItems": 5, - "maxItems": 5, - "description": ( - "Exactly five likely class names for segmenting the " - "visible table or tabletop target." - ), - "items": { - "type": "string", - "minLength": 1, - }, - }, - "object_coverage_percent": { - "type": "integer", - "enum": [10, 30, 50, 70], - "description": ( - "For image input with a complete visible table ONLY: " - "choose the closest coverage bucket for objects on the " - "tabletop: 10 (mostly empty, a few small objects), " - "30 (lightly cluttered), 50 (moderately cluttered), " - "70 (densely packed). Omit this field entirely for " - "text input or when is_complete_visible_table is false." - ), - }, - }, - "required": [ - "name", - "description", - "complete_table_description", - "is_complete_visible_table", - "class_candidate", - ], - }, - "assets": { - "type": "array", - "description": ( - "Object category groups on or intended for the tabletop scene." - ), - "items": { - "type": "object", - "additionalProperties": False, - "properties": { - "name": { - "type": "string", - "description": ( - "Canonical English object name, singular, " - "snake_case preferred." - ), - }, - "description": { - "type": "string", - "minLength": 20, - "maxLength": 180, - "description": ( - "One concise appearance description of the object for " - "image and 3D geometry generation." - ), - }, - "class_candidate": { - "type": "array", - "minItems": 5, - "maxItems": 5, - "description": ( - "Exactly five likely object class names for later " - "image detection or segmentation." - ), - "items": { - "type": "string", - "minLength": 1, - }, - }, - "count": { - "type": "integer", - "description": ( - "Number of repeated instances in this object category " - "group. Only group objects that can share the same name, " - "description, and class_candidate list." - ), - "minimum": 1, - }, - }, - "required": ["name", "description", "class_candidate", "count"], - }, - }, - }, - "required": ["table", "assets"], -} - @dataclass(frozen=True) class SceneIntakeInputRecord: diff --git a/embodichain/gen_sim/prompt2scene/workflows/text_relations/nodes.py b/embodichain/gen_sim/prompt2scene/workflows/text_relations/nodes.py index 67b1fc3c1..5adc8eef5 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/text_relations/nodes.py +++ b/embodichain/gen_sim/prompt2scene/workflows/text_relations/nodes.py @@ -19,8 +19,10 @@ from typing import Any from embodichain.gen_sim.prompt2scene.workflows.request import InputKind -from embodichain.gen_sim.prompt2scene.workflows.text_relations.schema import ( +from embodichain.gen_sim.prompt2scene.prompts.schemas import ( TEXT_RELATIONS_JSON_SCHEMA, +) +from embodichain.gen_sim.prompt2scene.workflows.text_relations.schema import ( TextRelationSpec, ) from embodichain.gen_sim.prompt2scene.utils import ( @@ -31,14 +33,14 @@ TEXT_RELATIONS_STEP, WorkflowArtifactWriter, ) -from embodichain.gen_sim.prompt2scene.workflows.llm_output import ( +from embodichain.gen_sim.prompt2scene.llms.llm_output import ( StructuredModelCallError, call_structured_json_model_step, ) from embodichain.gen_sim.prompt2scene.workflows.stage_errors import ( format_attempt_error, ) -from embodichain.gen_sim.prompt2scene.workflows.text_relations.prompts import ( +from embodichain.gen_sim.prompt2scene.prompts.builders import ( build_text_relation_messages, ) from embodichain.gen_sim.prompt2scene.workflows.text_relations.state import ( @@ -93,11 +95,11 @@ def call_llm_text_relations_node( schema=TEXT_RELATIONS_JSON_SCHEMA, messages=state["messages"], context="Text relations", - step_name=TEXT_RELATIONS_STEP, - output_root=None, + + attempt_count=attempt_count, - raw_output_label="extract", - artifact_writer=artifact_writer, + + ) except StructuredModelCallError as exc: error = format_attempt_error("Text relations", attempt_count, exc) diff --git a/embodichain/gen_sim/prompt2scene/workflows/text_relations/prompts.py b/embodichain/gen_sim/prompt2scene/workflows/text_relations/prompts.py deleted file mode 100644 index a6f02e4f6..000000000 --- a/embodichain/gen_sim/prompt2scene/workflows/text_relations/prompts.py +++ /dev/null @@ -1,55 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - -from __future__ import annotations - -from typing import Any - -from embodichain.gen_sim.prompt2scene.prompts import render_prompt -from embodichain.gen_sim.prompt2scene.workflows.request import Prompt2SceneInput -from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( - SceneIntakeSpec, -) - -__all__ = ["build_text_relation_messages"] - -TEXT_RELATIONS_PROMPT_NAME = "text_relations.yaml" - - -def build_text_relation_messages( - *, - request: Prompt2SceneInput, - scene_intake: SceneIntakeSpec, -) -> list[dict[str, Any]]: - """Build messages for explicit text spatial-relation extraction.""" - asset_names = "\n".join(f"- {asset.name}" for asset in scene_intake.assets) - return [ - { - "role": "system", - "content": render_prompt(TEXT_RELATIONS_PROMPT_NAME, prompt_key="system"), - }, - { - "role": "user", - "content": render_prompt( - TEXT_RELATIONS_PROMPT_NAME, - { - "asset_names": asset_names, - "text": request.text or "", - }, - prompt_key="user", - ), - }, - ] diff --git a/embodichain/gen_sim/prompt2scene/workflows/text_relations/schema.py b/embodichain/gen_sim/prompt2scene/workflows/text_relations/schema.py index db2e513ff..329f4dd21 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/text_relations/schema.py +++ b/embodichain/gen_sim/prompt2scene/workflows/text_relations/schema.py @@ -19,74 +19,13 @@ from dataclasses import dataclass, field from typing import Any -from embodichain.gen_sim.prompt2scene.workflows.spatial import ( - GRID_VALUE_LIST, - RELATION_VALUE_LIST, -) - __all__ = [ - "TEXT_RELATIONS_JSON_SCHEMA", "TextObjectLayout", "TextObjectRelation", "TextRelationSpec", "TextTableConstraint", ] -TEXT_RELATIONS_JSON_SCHEMA: dict[str, Any] = { - "title": "TextRelationsOutput", - "type": "object", - "additionalProperties": False, - "properties": { - "object_relations": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": False, - "properties": { - "subject": {"type": "string", "minLength": 1}, - "relation": { - "type": "string", - "enum": RELATION_VALUE_LIST, - }, - "object": {"type": "string", "minLength": 1}, - "evidence": {"type": "string", "minLength": 1}, - }, - "required": ["subject", "relation", "object", "evidence"], - }, - }, - "table_constraints": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": False, - "properties": { - "asset": {"type": "string", "minLength": 1}, - "grid": { - "type": "string", - "enum": GRID_VALUE_LIST, - }, - "evidence": {"type": "string", "minLength": 1}, - }, - "required": ["asset", "grid", "evidence"], - }, - }, - "object_layouts": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": False, - "properties": { - "asset": {"type": "string", "minLength": 1}, - "is_arbitrary_layout": {"type": "boolean"}, - "reason": {"type": "string", "minLength": 1}, - }, - "required": ["asset", "is_arbitrary_layout", "reason"], - }, - }, - }, - "required": ["object_relations", "table_constraints", "object_layouts"], -} - @dataclass(frozen=True) class TextObjectRelation: diff --git a/embodichain/gen_sim/prompt2scene/workflows/text_relations/utils.py b/embodichain/gen_sim/prompt2scene/workflows/text_relations/utils.py index 58002713b..dad38551e 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/text_relations/utils.py +++ b/embodichain/gen_sim/prompt2scene/workflows/text_relations/utils.py @@ -21,7 +21,7 @@ from embodichain.gen_sim.prompt2scene.workflows.scene_intake.schema import ( SceneIntakeSpec, ) -from embodichain.gen_sim.prompt2scene.workflows.spatial import ( +from embodichain.gen_sim.prompt2scene.agent_tools.tools.spatial_relations import ( GRID_VALUES, RELATION_VALUES, ) diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene/utils.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/utils.py index 49e4a70cb..7071176f7 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/unified_scene/utils.py +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene/utils.py @@ -23,7 +23,7 @@ ImageAnchor, ImageRelationSpec, ) -from embodichain.gen_sim.prompt2scene.workflows.spatial import ( +from embodichain.gen_sim.prompt2scene.agent_tools.tools.spatial_relations import ( assign_grids_from_anchor_and_orders, derive_relations_from_orders, transitive_relation_closure, diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/nodes.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/nodes.py index e12e41f12..347678ccc 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/nodes.py +++ b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/nodes.py @@ -45,16 +45,16 @@ from embodichain.gen_sim.prompt2scene.agent_tools.tools.image_scene_asset_generation import ( generate_image_scene_assets, ) -from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.paths import ( - UnifiedScenePaths, +from embodichain.gen_sim.prompt2scene.workflows.paths import ( + PipelinePaths, ) -from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.prompts import ( +from embodichain.gen_sim.prompt2scene.prompts.builders import ( build_text_metric_scale_messages, ) -from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.schema import ( +from embodichain.gen_sim.prompt2scene.prompts.schemas import ( IMAGE_METRIC_SCALE_JSON_SCHEMA, ) -from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.scene_update import ( +from embodichain.gen_sim.prompt2scene.workflows.unified_scene_gen.utils import ( update_unified_scene, ) @@ -72,7 +72,7 @@ def load_unified_scene_input_kind_node( state: UnifiedSceneGenState, ) -> dict[str, object]: """Load unified-scene output and determine the generation route.""" - paths = UnifiedScenePaths(state["output_root"]) + paths = PipelinePaths(state["output_root"]) result_path = paths.resolve_scene_result(state["unified_scene_result_path"]) if not result_path.is_file(): raise FileNotFoundError(f"Unified scene result not found: {result_path}") @@ -110,12 +110,12 @@ def generate_text_assets_node( if unified_scene is None: return {"generation_status": "no_unified_scene"} - paths = UnifiedScenePaths(state["output_root"]) + paths = PipelinePaths(state["output_root"]) output_root = paths.output_root image_gen_dir, glb_gen_dir, debug_dir = paths.prepare_generation_dirs() log_info( "generate_text_assets started " - f"output_dir={output_root / UNIFIED_SCENE_GEN_STEP}" + f"output_dir={paths.unified_scene_gen_dir}" ) table_spec = unified_scene.get("table") or {} @@ -191,12 +191,12 @@ def generate_image_assets_node(state: UnifiedSceneGenState) -> dict[str, object] if unified_scene is None: return {"generation_status": "no_unified_scene"} - paths = UnifiedScenePaths(state["output_root"]) + paths = PipelinePaths(state["output_root"]) output_root = paths.output_root image_gen_dir, glb_gen_dir, debug_dir = paths.prepare_generation_dirs() log_info( "generate_image_assets started " - f"output_dir={output_root / UNIFIED_SCENE_GEN_STEP}" + f"output_dir={paths.unified_scene_gen_dir}" ) segments_path = paths.image_segments_result @@ -261,7 +261,7 @@ def fit_image_table_to_clutter_node(state: UnifiedSceneGenState) -> dict[str, ob if state.get("input_kind") != "image": return {} - paths = UnifiedScenePaths(state["output_root"]) + paths = PipelinePaths(state["output_root"]) output_root = paths.output_root output_dir = paths.table_fit_dir output_dir.mkdir(parents=True, exist_ok=True) @@ -303,7 +303,7 @@ def generate_text_clutter_layout_node( if state.get("input_kind") != "text": return {} - paths = UnifiedScenePaths(state["output_root"]) + paths = PipelinePaths(state["output_root"]) output_root = paths.output_root output_dir = paths.text_clutter_dir output_dir.mkdir(parents=True, exist_ok=True) @@ -353,7 +353,7 @@ def fit_text_table_to_clutter_node( if state.get("input_kind") != "text": return {} - paths = UnifiedScenePaths(state["output_root"]) + paths = PipelinePaths(state["output_root"]) output_root = paths.output_root table_result = state.get("table_result") settle_result = state.get("text_clutter_settle_result") diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/paths.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/paths.py deleted file mode 100644 index c4af80541..000000000 --- a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/paths.py +++ /dev/null @@ -1,102 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - -from __future__ import annotations - -from dataclasses import dataclass -from pathlib import Path -from typing import Any - -from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( - IMAGE_SEGMENTS_STEP, - STEP_RESULT_FILENAME, - UNIFIED_SCENE_GEN_STEP, - UNIFIED_SCENE_STEP, -) - -__all__ = ["UnifiedScenePaths", "resolve_generated_path"] - - -def resolve_generated_path(value: Any, output_root: Path) -> Path: - """Resolve an absolute or output-root-relative generated artifact path.""" - if not value: - return Path() - path = Path(str(value)).expanduser() - if path.is_absolute(): - return path.resolve() - return (output_root.expanduser().resolve() / path).resolve() - - -@dataclass(frozen=True) -class UnifiedScenePaths: - """High-level paths owned by the unified-scene generation workflow.""" - - output_root: Path - - def __post_init__(self) -> None: - object.__setattr__( - self, - "output_root", - self.output_root.expanduser().resolve(), - ) - - @property - def workflow_root(self) -> Path: - return self.output_root / UNIFIED_SCENE_GEN_STEP - - @property - def image_gen_dir(self) -> Path: - return self.workflow_root / "image_gen" - - @property - def glb_gen_dir(self) -> Path: - return self.workflow_root / "glb_gen" - - @property - def debug_dir(self) -> Path: - return self.workflow_root / "debug" - - @property - def text_clutter_dir(self) -> Path: - return self.glb_gen_dir / "text_clutter_settled" - - @property - def table_fit_dir(self) -> Path: - return self.glb_gen_dir / "table_fit_to_clutter" - - @property - def image_segments_result(self) -> Path: - return self.output_root / IMAGE_SEGMENTS_STEP / STEP_RESULT_FILENAME - - def prepare_generation_dirs(self) -> tuple[Path, Path, Path]: - """Create and return the workflow's high-level generation directories.""" - directories = (self.image_gen_dir, self.glb_gen_dir, self.debug_dir) - for directory in directories: - directory.mkdir(parents=True, exist_ok=True) - return directories - - def resolve_scene_result(self, explicit_path: Path | None) -> Path: - """Resolve the unified-scene result produced by the preceding workflow.""" - if explicit_path is not None: - return explicit_path.expanduser().resolve() - - scene_dir = self.output_root / UNIFIED_SCENE_STEP - result_path = scene_dir / STEP_RESULT_FILENAME - if result_path.is_file(): - return result_path - - legacy_path = scene_dir / "results.json" - return legacy_path if legacy_path.is_file() else result_path diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/prompts.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/prompts.py deleted file mode 100644 index 1543acfb6..000000000 --- a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/prompts.py +++ /dev/null @@ -1,141 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - -from __future__ import annotations - -import json -from pathlib import Path -from typing import Any - -from embodichain.gen_sim.prompt2scene.prompts import render_prompt -from embodichain.gen_sim.prompt2scene.utils.io import image_to_data_url - -__all__ = [ - "build_image_metric_scale_messages", - "build_text_metric_scale_messages", - "build_up_down_flip_check_messages", -] - -UNIFIED_SCENE_GEN_PROMPT_NAME = "unified_scene_gen.yaml" - - -def build_image_metric_scale_messages( - *, - bbox_name_image_path: Path, - objects_json: list[dict[str, Any]], -) -> list[dict[str, Any]]: - """Build messages for image-scene object metric scale estimation.""" - return [ - { - "role": "system", - "content": render_prompt( - UNIFIED_SCENE_GEN_PROMPT_NAME, - prompt_key="image_metric_scale_system", - ), - }, - { - "role": "user", - "content": [ - { - "type": "text", - "text": render_prompt( - UNIFIED_SCENE_GEN_PROMPT_NAME, - { - "objects_json": json.dumps( - objects_json, - ensure_ascii=False, - indent=2, - ), - }, - prompt_key="image_metric_scale_user", - ), - }, - { - "type": "image_url", - "image_url": {"url": image_to_data_url(bbox_name_image_path)}, - }, - ], - }, - ] - - -def build_text_metric_scale_messages( - *, - user_text: str, - objects_json: list[dict[str, Any]], -) -> list[dict[str, Any]]: - """Build messages for text-scene object metric scale estimation.""" - return [ - { - "role": "system", - "content": render_prompt( - UNIFIED_SCENE_GEN_PROMPT_NAME, - prompt_key="text_metric_scale_system", - ), - }, - { - "role": "user", - "content": render_prompt( - UNIFIED_SCENE_GEN_PROMPT_NAME, - { - "user_text": user_text, - "objects_json": json.dumps( - objects_json, - ensure_ascii=False, - indent=2, - ), - }, - prompt_key="text_metric_scale_user", - ), - }, - ] - - -def build_up_down_flip_check_messages( - *, - original_image_path: Path, - comparison_image_path: Path, -) -> list[dict[str, Any]]: - """Build messages for VLM support-normal up/down flip verification.""" - return [ - { - "role": "system", - "content": render_prompt( - UNIFIED_SCENE_GEN_PROMPT_NAME, - prompt_key="up_down_flip_check_system", - ), - }, - { - "role": "user", - "content": [ - { - "type": "text", - "text": render_prompt( - UNIFIED_SCENE_GEN_PROMPT_NAME, - prompt_key="up_down_flip_check_user", - ), - }, - { - "type": "image_url", - "image_url": {"url": image_to_data_url(original_image_path)}, - }, - { - "type": "image_url", - "image_url": {"url": image_to_data_url(comparison_image_path)}, - }, - ], - }, - ] diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/schema.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/schema.py deleted file mode 100644 index b22fcebba..000000000 --- a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/schema.py +++ /dev/null @@ -1,71 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - -from __future__ import annotations - -from typing import Any - -__all__ = [ - "IMAGE_METRIC_SCALE_JSON_SCHEMA", - "UP_DOWN_FLIP_CHECK_JSON_SCHEMA", -] - -UP_DOWN_FLIP_CHECK_JSON_SCHEMA: dict[str, Any] = { - "title": "AlignedUpDownFlipCheckOutput", - "type": "object", - "additionalProperties": False, - "properties": { - "selected_number": {"type": "integer", "enum": [1, 2]}, - "confidence": {"type": "number", "minimum": 0.0, "maximum": 1.0}, - "reason": {"type": "string"}, - }, - "required": ["selected_number", "confidence", "reason"], -} - -IMAGE_METRIC_SCALE_JSON_SCHEMA: dict[str, Any] = { - "title": "ImageMetricScaleEstimate", - "type": "object", - "additionalProperties": False, - "properties": { - "object_scales": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": False, - "properties": { - "object_id": {"type": "string"}, - "bbox_dims_cm": { - "type": "array", - "minItems": 3, - "maxItems": 3, - "items": { - "type": "number", - "minimum": 1.0e-6, - }, - }, - "confidence": { - "type": "number", - "minimum": 0.0, - "maximum": 1.0, - }, - "reason": {"type": "string"}, - }, - "required": ["object_id", "bbox_dims_cm", "confidence", "reason"], - }, - }, - }, - "required": ["object_scales"], -} diff --git a/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/scene_update.py b/embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/utils.py similarity index 100% rename from embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/scene_update.py rename to embodichain/gen_sim/prompt2scene/workflows/unified_scene_gen/utils.py From 6a203083af463c0be892247666ade529bdc40b21 Mon Sep 17 00:00:00 2001 From: Muzi Wong <178915912+MuziWong@users.noreply.github.com> Date: Wed, 1 Jul 2026 19:00:02 +0800 Subject: [PATCH 6/7] 1. Fixed segmentation bug; 2. Fixed image input object position relation bug; 3. Fixed the flip objects clutter bug; 4. Added 2d layout info when exporting the gym project for better scene modifying... --- .../image_segmentation_client/utils.py | 60 ++- .../blender_rendering_manager/manager.py | 6 +- .../managers/geometry_manager/manager.py | 39 ++ .../managers/geometry_manager/utils.py | 14 + .../managers/matplotlib_manager/manager.py | 17 +- .../managers/simready_manager/manager.py | 15 +- .../tools/image_layout_alignment.py | 112 ++++- .../tools/image_scene_asset_generation.py | 11 +- .../agent_tools/tools/image_segment_filter.py | 86 ++++ .../agent_tools/tools/table_clutter_fit.py | 47 ++- .../gen_sim/prompt2scene/prompts/builders.py | 41 ++ .../prompts/data/image_relations.yaml | 107 ++++- .../prompts/data/scene_intake.yaml | 42 +- .../gen_sim/prompt2scene/prompts/schemas.py | 20 + .../prompt2scene/workflows/gym_export.py | 389 ++++++++++++++++-- .../workflows/image_relations/nodes.py | 41 +- .../workflows/image_relations/utils.py | 84 ++++ .../workflows/scene_intake/utils.py | 33 +- 18 files changed, 1036 insertions(+), 128 deletions(-) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/utils.py b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/utils.py index 834573588..3bef48923 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/utils.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/clients/image_segmentation_client/utils.py @@ -19,7 +19,8 @@ from pathlib import Path from typing import Any -from PIL import Image, ImageDraw, ImageFont +import numpy as np +from PIL import Image, ImageDraw, ImageFilter, ImageFont from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client.schemas import ( ImageSegmentationCandidate, @@ -154,12 +155,12 @@ def draw_numbered_masks( draw_overlay = ImageDraw.Draw(overlay) font = _load_label_font(image.width) colors = [ - (255, 64, 64, 110), - (64, 160, 255, 110), - (64, 220, 120, 110), - (255, 190, 64, 110), - (190, 96, 255, 110), - (255, 96, 190, 110), + (255, 64, 64, 255), + (64, 160, 255, 255), + (64, 220, 120, 255), + (255, 190, 64, 255), + (190, 96, 255, 255), + (255, 96, 190, 255), ] for index, segment in enumerate(segments, start=1): @@ -170,9 +171,10 @@ def draw_numbered_masks( if mask.size != image.size: mask = mask.resize(image.size, Image.Resampling.NEAREST) color = colors[(index - 1) % len(colors)] + outline = _mask_outline(mask) color_layer = Image.new("RGBA", image.size, color) transparent = Image.new("RGBA", image.size) - overlay.alpha_composite(Image.composite(color_layer, transparent, mask)) + overlay.alpha_composite(Image.composite(color_layer, transparent, outline)) _draw_mask_label( draw=draw_overlay, segment=segment, @@ -298,25 +300,45 @@ def _draw_mask_label( label: str, font: ImageFont.ImageFont, ) -> None: - bbox = mask.getbbox() - if bbox is None: + anchor = _mask_visible_pixel_centroid(mask) + if anchor is None: x1, y1, x2, y2 = segment["bbox_xyxy"] x = float(x1 + x2) * 0.5 y = float(y1 + y2) * 0.5 else: - x1, y1, x2, y2 = bbox - x = float(x1 + x2) * 0.5 - y = float(y1 + y2) * 0.5 + x, y = anchor - label_box = draw.textbbox((x, y), label, font=font) + label_box = draw.textbbox((0, 0), label, font=font) padding = 8 + label_width = label_box[2] - label_box[0] + label_height = label_box[3] - label_box[1] + text_x = x - label_width * 0.5 + text_y = y - label_height * 0.5 draw.rectangle( ( - label_box[0] - padding, - label_box[1] - padding, - label_box[2] + padding, - label_box[3] + padding, + text_x - padding, + text_y - padding, + text_x + label_width + padding, + text_y + label_height + padding, ), fill="red", + outline="white", + width=3, ) - draw.text((x, y), label, fill="white", font=font) + draw.text((text_x, text_y), label, fill="white", font=font) + + +def _mask_visible_pixel_centroid(mask: Image.Image) -> tuple[float, float] | None: + """Return the centroid of actual visible mask pixels, not the bbox center.""" + alpha = np.asarray(mask.convert("L"), dtype=np.uint8) + ys, xs = np.nonzero(alpha > 0) + if len(xs) == 0: + return None + return float(np.mean(xs)), float(np.mean(ys)) + + +def _mask_outline(mask: Image.Image) -> Image.Image: + """Return a thick binary outline so overlays do not recolor the object.""" + alpha = mask.convert("L") + edge = alpha.filter(ImageFilter.FIND_EDGES) + return edge.filter(ImageFilter.MaxFilter(5)) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/manager.py index 8617f2975..6ae3d0c6e 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/manager.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/blender_rendering_manager/manager.py @@ -158,9 +158,9 @@ def _front_oblique_script(glb_paths: list[Path], output_path: Path) -> str: light = bpy.data.objects.new("front_oblique_area_light", light_data) bpy.context.collection.objects.link(light) light.location = camera.location -light_data.energy = 600.0 +light_data.energy = 350.0 light_data.size = max(span_x, span_y) * 2.0 -bpy.context.scene.world.color = (1.0, 1.0, 1.0) +bpy.context.scene.world.color = (0.90, 0.90, 0.90) try: bpy.context.scene.render.engine = "BLENDER_EEVEE_NEXT" except Exception: @@ -169,7 +169,7 @@ def _front_oblique_script(glb_paths: list[Path], output_path: Path) -> str: bpy.context.scene.render.resolution_y = 768 bpy.context.scene.render.film_transparent = False bpy.context.scene.view_settings.view_transform = "Standard" -bpy.context.scene.view_settings.look = "Medium High Contrast" +bpy.context.scene.view_settings.look = "Medium Contrast" bpy.context.scene.render.filepath = output_path bpy.ops.render.render(write_still=True) """ diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/manager.py index fa42ead4a..ec7a18c6f 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/manager.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/manager.py @@ -221,6 +221,12 @@ def table_fit_uniform_xy_scale_transform(**kwargs: Any) -> Any: return geometry_utils._table_fit_uniform_xy_scale_transform(**kwargs) + @staticmethod + def table_fit_uniform_scale_transform(**kwargs: Any) -> Any: + from . import utils as geometry_utils + + return geometry_utils._table_fit_uniform_scale_transform(**kwargs) + @staticmethod def table_fit_safe_positive_ratio(numerator: float, denominator: float) -> float: from . import utils as geometry_utils @@ -336,6 +342,39 @@ def mesh_aabb_size(mesh: Any) -> Any: raise ValueError(f"Mesh AABB size must be positive, got {size.tolist()}.") return size + @staticmethod + def mesh_pca_bbox_size(mesh: Any) -> Any: + """Return bbox extents in the mesh PCA frame. + + This is used for metric-scale estimation because it is less sensitive + to arbitrary object yaw/tilt than a world-axis AABB. + """ + vertices = np.asarray(mesh.vertices, dtype=np.float64) + if vertices.ndim != 2 or vertices.shape[1] != 3 or len(vertices) < 3: + return GeometryManager.mesh_aabb_size(mesh) + + centered = vertices - np.mean(vertices, axis=0) + cov = np.cov(centered, rowvar=False) + if cov.shape != (3, 3) or not np.all(np.isfinite(cov)): + return GeometryManager.mesh_aabb_size(mesh) + + eigvals, eigvecs = np.linalg.eigh(cov) + order = np.argsort(eigvals)[::-1] + axes = eigvecs[:, order] + if np.linalg.det(axes) < 0.0: + axes[:, -1] *= -1.0 + + projected = centered @ axes + size = projected.max(axis=0) - projected.min(axis=0) + if np.any(size <= 0.0) or not np.all(np.isfinite(size)): + return GeometryManager.mesh_aabb_size(mesh) + return size + + @staticmethod + def mesh_metric_bbox_size(mesh: Any) -> Any: + """Return the bbox size used by metric-scale estimation.""" + return GeometryManager.mesh_pca_bbox_size(mesh) + @staticmethod def bbox_ratio(size: Any) -> Any: """Return bbox dimensions normalized by the largest axis.""" diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/utils.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/utils.py index be502fbbe..f7a5bcfec 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/utils.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/geometry_manager/utils.py @@ -495,6 +495,20 @@ def _table_fit_uniform_xy_scale_transform( return center @ scale_mat @ uncenter +def _table_fit_uniform_scale_transform( + *, + center_xy: np.ndarray, + scale: float, +) -> np.ndarray: + center = np.eye(4, dtype=np.float64) + center[:3, 3] = [float(center_xy[0]), float(center_xy[1]), 0.0] + uncenter = np.eye(4, dtype=np.float64) + uncenter[:3, 3] = [-float(center_xy[0]), -float(center_xy[1]), 0.0] + scale_mat = np.eye(4, dtype=np.float64) + scale_mat[:3, :3] *= float(scale) + return center @ scale_mat @ uncenter + + def _table_fit_safe_positive_ratio(numerator: float, denominator: float) -> float: return max(float(numerator) / max(float(denominator), 1.0e-6), 1.0e-6) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/manager.py index 1feb13c3f..4b5c58721 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/manager.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/matplotlib_manager/manager.py @@ -155,7 +155,22 @@ def render_image_comparison( (axes[1], second_image, request.second_label), ): ax.imshow(image) - ax.set_title(label, fontsize=16, loc="left") + ax.text( + 0.03, + 0.92, + label, + transform=ax.transAxes, + ha="left", + va="top", + fontsize=16, + color="white", + bbox={ + "boxstyle": "round,pad=0.25", + "facecolor": "black", + "edgecolor": "none", + "alpha": 0.55, + }, + ) ax.axis("off") fig.tight_layout() fig.savefig(output_path, dpi=self._dpi, facecolor="white") diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/manager.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/manager.py index ae46a552f..9b265ddfb 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/manager.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/simready_manager/manager.py @@ -367,11 +367,12 @@ def build_object_payload(objects): payload = [] for obj in objects: mesh = geom.load_mesh(LoadMeshRequest(mesh_path=obj.mesh_path)).mesh - normalized_bbox_size_m = GeometryManager.mesh_aabb_size(mesh) + normalized_bbox_size_m = GeometryManager.mesh_metric_bbox_size(mesh) payload.append({ "object_id": obj.object_id, "object_name": obj.object_name, "object_description": obj.object_description, + "normalized_bbox_method": "pca_bbox", "normalized_bbox_size_m": normalized_bbox_size_m.tolist(), "normalized_bbox_ratio": GeometryManager.bbox_ratio( normalized_bbox_size_m @@ -466,6 +467,7 @@ def select_candidate(*, object_id, object_name, object_description, bbox_dims_cm "object_id": object_id, "object_name": object_name, "object_description": object_description, + "normalized_bbox_method": "pca_bbox", "normalized_bbox_size_m": normalized_bbox_size_m.tolist(), "normalized_bbox_size_cm": nbs_cm.tolist(), "normalized_bbox_ratio": GeometryManager.bbox_ratio( @@ -579,10 +581,14 @@ def compute_global_from_object_scenes(request): if srs.shape != (3,) or np.any(srs <= 0.0): skipped.append({"id": object_id, "reason": "invalid_normalized_bbox_size_m"}) continue - cb = np.asarray(GeometryManager.scene_to_mesh(scene).bounds) - cs = cb[1] - cb[0] + cs = np.asarray( + GeometryManager.mesh_metric_bbox_size( + GeometryManager.scene_to_mesh(scene) + ), + dtype=np.float64, + ) if cs.shape != (3,) or np.any(cs <= 0.0): - skipped.append({"id": object_id, "reason": "invalid_current_scene_aabb"}) + skipped.append({"id": object_id, "reason": "invalid_current_scene_bbox"}) continue geo_ratio = np.sort(cs) / np.sort(srs) geo_scale = float(np.median(geo_ratio)) @@ -641,4 +647,3 @@ def compute_global_from_object_scenes(request): f"clamped to [{request.min_scale:.2f}, {request.max_scale:.2f}]." ), } - diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_layout_alignment.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_layout_alignment.py index ede21b08b..f6ca1d9dd 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_layout_alignment.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_layout_alignment.py @@ -173,6 +173,15 @@ def _export_support_aligned_layout_glbs( object_scenes = selected_candidate["object_scenes"] selected_extra_transform = selected_candidate["extra_transform"] apply_up_down_flip = selected_variant == "flipped" + complete_table_relative_scale_hint = _complete_table_relative_scale_hint( + table=table, + support_reference_scene=support_reference_scene, + object_scenes=object_scenes, + table_alignment_matrix=selected_extra_transform + @ center_transform + @ normal_alignment, + trimesh=trimesh, + ) global_metric_scale = MetricScaleManager.compute_global_from_object_scenes( GlobalMetricScaleRequest( @@ -227,6 +236,7 @@ def _export_support_aligned_layout_glbs( "pre_metric_scale_alignment_matrix": alignment_matrix.tolist(), "global_metric_scale": global_metric_scale, "final_clutter_2d_aabb_cm": final_clutter_aabb_2d_cm, + "complete_table_relative_scale_hint": complete_table_relative_scale_hint, "internal_up_axis": [0.0, 0.0, 1.0], "glb_output_up_axis": [0.0, 1.0, 0.0], "glb_output_axis_transform": output_axis_transform.tolist(), @@ -256,6 +266,73 @@ def _export_support_aligned_layout_glbs( } +def _complete_table_relative_scale_hint( + *, + table: dict[str, Any], + support_reference_scene: Any, + object_scenes: list[tuple[str, Any]], + table_alignment_matrix: np.ndarray, + trimesh: Any, +) -> dict[str, Any]: + if not table.get("is_complete_visible_table"): + return { + "status": "skipped", + "reason": "table_is_not_complete_visible", + } + if not object_scenes: + return { + "status": "skipped", + "reason": "missing_object_scenes", + } + try: + table_scene = GeometryManager.copy_scene_with_transform( + support_reference_scene, + table_alignment_matrix, + ) + raw_clutter_bounds = GeometryManager.table_fit_scene_union_bounds( + [scene for _, scene in object_scenes], + trimesh=trimesh, + ) + raw_clutter_size_xy = GeometryManager.xy_aabb_size(raw_clutter_bounds) + raw_table_mesh = GeometryManager.scene_to_mesh(table_scene, trimesh=trimesh) + raw_table_support = GeometryManager.detect_table_fit_support_quad( + raw_table_mesh, + target_aspect=float( + raw_clutter_size_xy[0] / max(float(raw_clutter_size_xy[1]), 1.0e-6) + ), + ) + raw_table_support_size_xy = np.asarray( + raw_table_support["size_xy"], + dtype=np.float64, + ) + ratio_xy = raw_table_support_size_xy / np.maximum( + raw_clutter_size_xy, + 1.0e-6, + ) + if not np.all(np.isfinite(ratio_xy)) or np.any(ratio_xy <= 0.0): + return { + "status": "skipped", + "reason": "invalid_raw_relative_size", + } + return { + "status": "ok", + "method": "complete_table_sam3d_raw_support_to_clutter_ratio", + "raw_table_support_size_xy": raw_table_support_size_xy.tolist(), + "raw_clutter_size_xy": raw_clutter_size_xy.tolist(), + "support_to_clutter_size_ratio_xy": ratio_xy.tolist(), + "raw_table_support_quad": raw_table_support, + "note": ( + "Ratio is unitless and is computed before metric scaling; " + "table fit later applies one uniform XYZ scale to the simready table." + ), + } + except Exception: + return { + "status": "failed", + "reason": traceback.format_exc(), + } + + def _build_up_down_alignment_candidates( *, object_scenes: list[tuple[str, Any]], @@ -429,6 +506,19 @@ def _resolve_generated_path(value: Any, output_root: Path) -> Path: return (output_root / path).resolve() +def _write_json_file(path: Path, payload: dict[str, Any]) -> None: + try: + import json + + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + json.dumps(payload, ensure_ascii=False, indent=2) + "\n", + encoding="utf-8", + ) + except Exception: + pass + + def _run_aligned_up_down_flip_vlm_check( *, llm: Any | None, @@ -446,6 +536,7 @@ def _run_aligned_up_down_flip_vlm_check( } if not normal_object_scenes or not flipped_object_scenes: result["reason"] = "missing_object_scenes" + _write_json_file(output_dir / "up_down_flip_selection.json", result) return result try: @@ -473,9 +564,11 @@ def _run_aligned_up_down_flip_vlm_check( ) if llm is None: result["reason"] = "missing_llm" + _write_json_file(output_dir / "up_down_flip_selection.json", result) return result if original_image_path is None or not original_image_path.is_file(): result["reason"] = "missing_original_image" + _write_json_file(output_dir / "up_down_flip_selection.json", result) return result raw_model_output = call_structured_json_model_step( @@ -486,21 +579,12 @@ def _run_aligned_up_down_flip_vlm_check( comparison_image_path=comparison_image_path, ), context="Unified scene aligned up-down flip check", - step_name=UNIFIED_SCENE_STEP, - output_root=None, attempt_count=0, + raw_output_writer=lambda payload: _write_json_file( + output_dir / "vlm_flip_check_result.json", + payload, + ), ) - # Persist VLM raw output alongside the comparison renders - try: - import json as _json - - vlm_result_path = output_dir / "vlm_flip_check_result.json" - vlm_result_path.write_text( - _json.dumps(raw_model_output, ensure_ascii=False, indent=2), - encoding="utf-8", - ) - except Exception: - pass confidence = float(raw_model_output.get("confidence", 0.0)) selected_number = int(raw_model_output.get("selected_number", 1)) if selected_number not in {1, 2}: @@ -524,6 +608,7 @@ def _run_aligned_up_down_flip_vlm_check( "reason": str(raw_model_output.get("reason", "")), } ) + _write_json_file(output_dir / "up_down_flip_selection.json", result) return result except Exception: result.update( @@ -532,4 +617,5 @@ def _run_aligned_up_down_flip_vlm_check( "reason": traceback.format_exc(), } ) + _write_json_file(output_dir / "up_down_flip_selection.json", result) return result diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py index 4bc8cbb5b..d57da062e 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_scene_asset_generation.py @@ -430,6 +430,10 @@ def generate_image_scene_assets( generated_object["aligned_geometry_path"] = aligned_object[ "aligned_geometry_path" ] + if isinstance(generated_table, dict): + generated_table["complete_table_relative_scale_hint"] = ( + alignment_result.get("complete_table_relative_scale_hint") + ) except Exception as exc: status = "failed" failure_reason = traceback.format_exc() @@ -463,6 +467,7 @@ def generate_image_scene_assets( "simready_geometry_path", "aligned_geometry_path", "mesh_path", + "complete_table_relative_scale_hint", ) object_fields = ( "id", @@ -491,7 +496,11 @@ def generate_image_scene_assets( workflow_alignment = ( { key: alignment_result[key] - for key in ("status", "final_clutter_2d_aabb_cm") + for key in ( + "status", + "final_clutter_2d_aabb_cm", + "complete_table_relative_scale_hint", + ) if key in alignment_result } if alignment_result is not None diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_segment_filter.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_segment_filter.py index 28d3a99ab..62b88bdd8 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_segment_filter.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/image_segment_filter.py @@ -19,7 +19,10 @@ from pathlib import Path from typing import Any, Callable +import numpy as np + from embodichain.gen_sim.prompt2scene.agent_tools.clients.image_segmentation_client import ( + decode_rle_mask, draw_numbered_masks, sort_segments_by_bbox, ) @@ -40,11 +43,14 @@ __all__ = [ "filter_group_segments_with_vlm", + "filter_overlapping_confirmed_segments", "filter_segments_with_vlm", "remove_extra_numbered_segments", ] DebugWriter = Callable[[str, str, dict[str, Any]], Path] +CONFIRMED_MASK_COVERAGE_THRESHOLD = 0.85 +CONFIRMED_MASK_IOU_THRESHOLD = 0.70 def remove_extra_numbered_segments( @@ -77,6 +83,7 @@ def filter_group_segments_with_vlm( debug_round_name: str, debug_round_dir: Path, write_debug_json: DebugWriter, + confirmed_segments: list[dict[str, Any]] | None = None, ) -> list[dict[str, Any]]: """Ask VLM to remove wrong or duplicate instances from one SAM3 result. @@ -84,6 +91,10 @@ def filter_group_segments_with_vlm( *debug_round_dir*, and *write_debug_json* so the tool does not depend on workflow internals. """ + segments = filter_overlapping_confirmed_segments( + segments=segments, + confirmed_segments=confirmed_segments or [], + ) segments = sort_segments_by_bbox(segments) if not segments: return segments @@ -150,6 +161,7 @@ def filter_segments_with_vlm( """ result_groups: list[dict[str, Any]] = [] current_attempt = attempt_count + 1 + confirmed_segments: list[dict[str, Any]] = [] try: for group in segment_groups: @@ -167,7 +179,9 @@ def filter_segments_with_vlm( debug_round_name=round_name, debug_round_dir=round_dir, write_debug_json=write_debug_json, + confirmed_segments=confirmed_segments, ) + confirmed_segments.extend(group["segments"]) result_groups.append(group) except Exception as exc: if is_model_output_error(exc) or isinstance(exc, ValueError): @@ -187,3 +201,75 @@ def filter_segments_with_vlm( "segment_groups": result_groups, "last_error": None, } + + +def filter_overlapping_confirmed_segments( + *, + segments: list[dict[str, Any]], + confirmed_segments: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Drop candidates that almost exactly duplicate an already accepted mask.""" + if not segments or not confirmed_segments: + return segments + filtered: list[dict[str, Any]] = [] + confirmed_masks = [ + _segment_mask_array(segment) + for segment in confirmed_segments + if segment.get("mask_rle") is not None + ] + confirmed_masks = [mask for mask in confirmed_masks if mask is not None] + if not confirmed_masks: + return segments + + for segment in segments: + candidate_mask = _segment_mask_array(segment) + if candidate_mask is None: + filtered.append(segment) + continue + if any( + _is_almost_confirmed_duplicate( + candidate_mask=candidate_mask, + confirmed_mask=confirmed_mask, + ) + for confirmed_mask in confirmed_masks + ): + continue + filtered.append(segment) + removed_count = len(segments) - len(filtered) + if removed_count: + log.log_info( + "removed overlapping confirmed segment candidates " + f"before VLM count={removed_count}" + ) + return filtered + + +def _segment_mask_array(segment: dict[str, Any]) -> np.ndarray | None: + mask_rle = segment.get("mask_rle") + if mask_rle is None: + return None + try: + return np.asarray(decode_rle_mask(mask_rle).convert("L"), dtype=np.uint8) > 0 + except Exception: + return None + + +def _is_almost_confirmed_duplicate( + *, + candidate_mask: np.ndarray, + confirmed_mask: np.ndarray, +) -> bool: + if candidate_mask.shape != confirmed_mask.shape: + return False + candidate_area = int(np.count_nonzero(candidate_mask)) + confirmed_area = int(np.count_nonzero(confirmed_mask)) + if candidate_area <= 0 or confirmed_area <= 0: + return False + intersection = int(np.count_nonzero(candidate_mask & confirmed_mask)) + union = candidate_area + confirmed_area - intersection + candidate_covered = intersection / float(candidate_area) + iou = intersection / float(union) if union > 0 else 0.0 + return ( + candidate_covered >= CONFIRMED_MASK_COVERAGE_THRESHOLD + or iou >= CONFIRMED_MASK_IOU_THRESHOLD + ) diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_clutter_fit.py b/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_clutter_fit.py index e3e6f5296..6a5789890 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_clutter_fit.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/tools/table_clutter_fit.py @@ -50,7 +50,7 @@ def _gravity_settle_table_fit_internal_z_scene( *, z_to_y: np.ndarray, sim_device: str, -) -> Any: +) -> tuple[Any, np.ndarray]: sim = SimulationManager(headless=True, sim_device=sim_device) with tempfile.TemporaryDirectory(prefix="p2s_table_fit_gravity_") as tmp: tmp_path = Path(tmp) @@ -63,9 +63,10 @@ def _gravity_settle_table_fit_internal_z_scene( initial_height=0.05, ) ) + gravity_transform = np.asarray(result.final_pose, dtype=np.float64) settled = scene.copy() - settled.apply_transform(np.asarray(result.final_pose, dtype=np.float64)) - return settled + settled.apply_transform(gravity_transform) + return settled, gravity_transform def fit_table_to_clutter( @@ -146,6 +147,8 @@ def fit_table_to_clutter( trimesh=trimesh, y_to_z=y_to_z, ) + table_fit_transform = np.eye(4, dtype=np.float64) + table_mesh = GeometryManager.scene_to_mesh(table_scene, trimesh=trimesh) clutter_aabb = clutter_result.get("clutter_2d_aabb_cm") or {} clutter_size = clutter_aabb.get("size_xy", [1.0, 1.0]) @@ -172,7 +175,7 @@ def fit_table_to_clutter( trimesh=trimesh, ) - # Compute the required table size and uniform scale. + # Compute the required table size and optional uniform scale. clutter_size_cm = (clutter_bounds[1, :2] - clutter_bounds[0, :2]) * 100.0 if object_coverage_percent is not None: support_occupancy_ratio = float( @@ -180,6 +183,20 @@ def fit_table_to_clutter( ) occupancy = float(np.clip(support_occupancy_ratio, 0.1, 1.0)) required_size_cm = clutter_size_cm / occupancy + 2.0 * float(margin_cm) + scale_method = "fit_to_clutter_occupancy_margin" + relative_scale_hint = None + if table_result.get("is_complete_visible_table"): + hint = table_result.get("complete_table_relative_scale_hint") + if isinstance(hint, dict) and hint.get("status") == "ok": + ratio_xy = np.asarray( + hint.get("support_to_clutter_size_ratio_xy", []), + dtype=np.float64, + ) + if ratio_xy.shape == (2,) and np.all(np.isfinite(ratio_xy)): + ratio_xy = np.maximum(ratio_xy, 1.0) + required_size_cm = clutter_size_cm * ratio_xy + scale_method = "complete_table_sam3d_raw_relative_uniform_xyz" + relative_scale_hint = hint support_size_cm = np.asarray(initial_support["size_xy"], dtype=np.float64) * 100.0 scale_x = GeometryManager.table_fit_safe_positive_ratio( required_size_cm[0], @@ -190,19 +207,27 @@ def fit_table_to_clutter( support_size_cm[1], ) uniform_scale = max(scale_x, scale_y) - table_scale_transform = GeometryManager.table_fit_uniform_xy_scale_transform( - center_xy=np.asarray(initial_support["center_xy"], dtype=np.float64), - scale=uniform_scale, - ) + if scale_method == "complete_table_sam3d_raw_relative_uniform_xyz": + table_scale_transform = GeometryManager.table_fit_uniform_scale_transform( + center_xy=np.asarray(initial_support["center_xy"], dtype=np.float64), + scale=uniform_scale, + ) + else: + table_scale_transform = GeometryManager.table_fit_uniform_xy_scale_transform( + center_xy=np.asarray(initial_support["center_xy"], dtype=np.float64), + scale=uniform_scale, + ) table_scene.apply_transform(table_scale_transform) + table_fit_transform = table_scale_transform @ table_fit_transform # Settle the table under gravity. if gravity_settle_table: - table_scene = _gravity_settle_table_fit_internal_z_scene( + table_scene, gravity_transform = _gravity_settle_table_fit_internal_z_scene( table_scene, z_to_y=z_to_y, sim_device=sim_device, ) + table_fit_transform = gravity_transform @ table_fit_transform # Reposition the table at the origin. final_table_mesh = GeometryManager.scene_to_mesh(table_scene, trimesh=trimesh) @@ -217,6 +242,7 @@ def fit_table_to_clutter( table_shift = np.eye(4, dtype=np.float64) table_shift[:3, 3] = [-support_center[0], -support_center[1], -table_bottom_z] table_scene.apply_transform(table_shift) + table_fit_transform = table_shift @ table_fit_transform support_z_after = float((support_center + table_shift[:3, 3])[2]) # Measure the table surface height. @@ -311,15 +337,18 @@ def fit_table_to_clutter( "gravity_settle_table": gravity_settle_table, "table_bottom_z_after_shift": 0.0, "support_z_after_shift": support_z_after, + "table_fit_transform": table_fit_transform.tolist(), "initial_support_quad": initial_support, "final_support_quad_centered": final_support_centered, "clutter_2d_aabb_cm": final_clutter_aabb_cm, "required_support_size_cm": required_size_cm.tolist(), "table_xy_scale": { + "method": scale_method, "uniform_scale": uniform_scale, "scale_x_raw": scale_x, "scale_y_raw": scale_y, "support_size_before_scale_cm": support_size_cm.tolist(), + "complete_table_relative_scale_hint": relative_scale_hint, }, "fit_check": { "fits_width": float(final_clutter_aabb_cm["size_xy"][0]) diff --git a/embodichain/gen_sim/prompt2scene/prompts/builders.py b/embodichain/gen_sim/prompt2scene/prompts/builders.py index 8596c32dc..99018e329 100644 --- a/embodichain/gen_sim/prompt2scene/prompts/builders.py +++ b/embodichain/gen_sim/prompt2scene/prompts/builders.py @@ -29,6 +29,7 @@ "build_scene_intake_messages", "build_scene_intake_verifier_messages", "build_spatial_layout_messages", + "build_spatial_layout_verifier_messages", "build_text_metric_scale_messages", "build_text_relation_messages", "build_up_down_flip_check_messages", @@ -262,6 +263,46 @@ def build_spatial_layout_messages( ] +def build_spatial_layout_verifier_messages( + *, + bbox_name_image_path: Path, + asset_ids: list[str], + draft_spatial_layout_json: str, +) -> list[dict[str, Any]]: + """Build messages for VLM spatial ordering verification.""" + return [ + { + "role": "system", + "content": render_prompt( + IMAGE_RELATIONS_PROMPT, + prompt_key="spatial_layout_verifier_system", + ), + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": render_prompt( + IMAGE_RELATIONS_PROMPT, + { + "asset_ids": "\n".join( + f"- {asset_id}" for asset_id in asset_ids + ), + "draft_spatial_layout_json": draft_spatial_layout_json, + }, + prompt_key="spatial_layout_verifier_user", + ), + }, + { + "type": "image_url", + "image_url": {"url": image_to_data_url(bbox_name_image_path)}, + }, + ], + }, + ] + + def build_text_relation_messages( diff --git a/embodichain/gen_sim/prompt2scene/prompts/data/image_relations.yaml b/embodichain/gen_sim/prompt2scene/prompts/data/image_relations.yaml index 50ed69647..a6e4082f7 100644 --- a/embodichain/gen_sim/prompt2scene/prompts/data/image_relations.yaml +++ b/embodichain/gen_sim/prompt2scene/prompts/data/image_relations.yaml @@ -12,8 +12,8 @@ filter_extra_instances_system: | - One target object description. - The expected number of target instances. - A short candidate class list for that target object. - - One image with numbered colored masks drawn over candidate segmentation - results for that target object. + - One image with numbered colored mask outlines drawn over candidate + segmentation results for that target object. Your only task is to choose which numbered masks should be removed so the remaining masks best match the requested object class, target description, and @@ -25,6 +25,16 @@ filter_extra_instances_system: | + - The colored mask outlines and red number labels are annotations only. Do NOT + treat annotation colors as the object's real color, material, texture, label, + logo, or surface appearance. + - The red number label is only an approximate index marker. It may sit near, + above, or partly over a neighboring or occluding object. Do NOT decide mask + ownership from where the number label is drawn. + - Decide whether a candidate is correct from the region enclosed by the colored + mask outline, not from the label position. + - Judge the object's real appearance from the underlying original image pixels + inside and around each outline. - Use the target object class name as the primary class. - Use the target description to distinguish visually similar objects from the same broad category. @@ -124,7 +134,12 @@ filter_extra_instances_user: | - Inspect the numbered-mask image. + Inspect the numbered-mask-outline image. + The colored outlines and number labels are annotations only, not real object + colors or textures. + The red number label is only an approximate index marker; judge each candidate + by the object region enclosed by that numbered outline, not by where the label + happens to be placed. Return the 1-based numbers of masks that should be removed so the remaining masks best match the target description and expected instance count. @@ -151,6 +166,12 @@ spatial_layout_system: | - x_order must be ordered from image/table left to image/table right. + - The camera view may be oblique, rotated, or perspective-distorted. Still + judge left/right from the visible tabletop image as a human would understand + the normal viewing direction of this image, not from object IDs, prompt + order, 3D generation order, or arbitrary simulation axes. + - The final left_of relations derived from x_order must reflect your visual + judgment from the attached image under that normal image/table view. - y_order must be ordered from table front to table back. - Split x_order groups when the left/right order is reasonably clear from the bbox-name image. @@ -236,3 +257,83 @@ spatial_layout_user: | Inspect the attached bbox-name image and return the JSON object. + +spatial_layout_verifier_system: | + + You are a strict verifier for tabletop spatial ordering. + + + + You will receive one bbox-name tabletop image, the full asset_id list, and a + draft spatial layout JSON. Verify whether the draft x_order and y_order are + visually correct. + + If the draft is correct, set passed=true and return the same layout in + corrected_layout. If any order is wrong or too specific/uncertain, set + passed=false and return a corrected complete spatial layout JSON in + corrected_layout. + + + + - x_order is the source for derived left_of relations. Check it carefully. + - The camera view may be oblique, rotated, or perspective-distorted. Still + judge left/right from the visible tabletop image as a human would understand + the normal viewing direction of this image, not from object IDs, prompt + order, 3D generation order, or arbitrary simulation axes. + - If a left/right relation is unclear, overlapping, occluded, or visually too + close to trust, place the objects in the same x_order group instead of + forcing an order. + - Check y_order more conservatively than x_order. If front/back is uncertain, + place objects in the same y_order group. + - Every asset_id must appear exactly once in corrected_layout.x_order, + corrected_layout.y_order, and corrected_layout.asset_states. + - Preserve or correct anchor and asset_states as needed, but focus primarily + on spatial order correctness. + - Return JSON only. + + + + { + "passed": false, + "reason": "The draft placed two overlapping objects in a forced left/right order, but the image does not support that relation.", + "corrected_layout": { + "anchor": { + "asset_id": "interact_paper_cup_0", + "grid": "center", + "reason": "The paper cup is clearly visible near the table center." + }, + "x_order": [ + ["interact_paper_cup_0", "interact_snack_bag_0"] + ], + "y_order": [ + ["interact_paper_cup_0", "interact_snack_bag_0"] + ], + "asset_states": [ + { + "asset_id": "interact_paper_cup_0", + "is_arbitrary_layout": false, + "reason": "The paper cup is standing upright and needs an upright support pose." + }, + { + "asset_id": "interact_snack_bag_0", + "is_arbitrary_layout": false, + "reason": "The snack bag should keep a deliberate lying or leaning support pose." + } + ] + } + } + + +spatial_layout_verifier_user: | + Verify this draft spatial layout for the detected object instances: + + + $asset_ids + + + + $draft_spatial_layout_json + + + Inspect the attached bbox-name image. Return whether the draft passes, the + reason, and a complete corrected_layout JSON. diff --git a/embodichain/gen_sim/prompt2scene/prompts/data/scene_intake.yaml b/embodichain/gen_sim/prompt2scene/prompts/data/scene_intake.yaml index bbdbbc8b0..d4dc81973 100644 --- a/embodichain/gen_sim/prompt2scene/prompts/data/scene_intake.yaml +++ b/embodichain/gen_sim/prompt2scene/prompts/data/scene_intake.yaml @@ -237,18 +237,22 @@ image_system: | - table.name, table.description, table.complete_table_description, table.class_candidate, and every asset.description must be non-empty. - Descriptions are used to generate images and then 3D geometry. - - Write each description as one concise English sentence, normally 8 to 20 - words. + - Write each image asset description as one detailed English sentence, normally + 18 to 35 words. - Every description must describe a SINGLE STANDALONE OBJECT isolated on a pure-white background. Do NOT mention any other object, the table, the scene, the room, or any background context. - Do NOT include any spatial, positional, or layout information such as "sitting on the table", "placed in front of", "to the left of", "on a surface", "on the tabletop", etc. - - When describing an object, first state what the object is, then mention - visible texture, color, shape, material, and similar appearance details. - - Keep descriptions simple. Focus only on what the object looks like, not - where it is or how it relates to anything else. + - When describing an object, first state what the object is, then describe its + visible texture, color, shape, material, surface finish, graphic markings, + openings, handles, caps, labels, and compound structure when visible. + - For composite objects, describe the visible parts and how they form one + standalone object, such as body, cap, lid, label, wrapper, handle, rim, base, + nozzle, straw, or attached accessory. + - Focus only on what the object itself looks like, not where it is or how it + relates to anything else. - For IMAGE inputs, include ONLY information supported by the image. Do NOT invent or embellish details not visible in the image. If a colour is ambiguous, use a reasonable neutral description ("light-colored", @@ -328,16 +332,16 @@ image_system: | }, "assets": [ { - "name": "apple", - "description": "A round apple with smooth red skin visible on the table.", - "class_candidate": ["apple", "fruit", "red_apple", "food", "produce"], + "name": "plastic_water_bottle", + "description": "A clear plastic water bottle with a ribbed cylindrical body, transparent glossy surface, narrow neck, blue screw cap, and printed paper label.", + "class_candidate": ["plastic_water_bottle", "water_bottle", "plastic_bottle", "bottle", "drink_container"], "count": 1 }, { - "name": "coffee_mug", - "description": "A white ceramic coffee mug with a curved handle.", - "class_candidate": ["coffee_mug", "ceramic_mug", "mug", "cup", "drinkware"], - "count": 2 + "name": "sports_bottle", + "description": "A matte dark sports bottle with a tapered body, textured grip band, rounded shoulder, flip-top cap, and solid opaque plastic construction.", + "class_candidate": ["sports_bottle", "water_bottle", "drink_bottle", "bottle", "container"], + "count": 1 } ] } @@ -466,9 +470,15 @@ verifier_system: | }, "assets": [ { - "name": "paper_cup", - "description": "A small white paper cup with blue printed details.", - "class_candidate": ["paper_cup", "disposable_cup", "cup", "drinkware", "container"], + "name": "plastic_water_bottle", + "description": "A clear plastic water bottle with a ribbed cylindrical body, transparent glossy surface, narrow neck, blue screw cap, and printed paper label.", + "class_candidate": ["plastic_water_bottle", "water_bottle", "plastic_bottle", "bottle", "drink_container"], + "count": 1 + }, + { + "name": "sports_bottle", + "description": "A matte dark sports bottle with a tapered body, textured grip band, rounded shoulder, flip-top cap, and solid opaque plastic construction.", + "class_candidate": ["sports_bottle", "water_bottle", "drink_bottle", "bottle", "container"], "count": 1 } ] diff --git a/embodichain/gen_sim/prompt2scene/prompts/schemas.py b/embodichain/gen_sim/prompt2scene/prompts/schemas.py index 20d617962..1c9893b93 100644 --- a/embodichain/gen_sim/prompt2scene/prompts/schemas.py +++ b/embodichain/gen_sim/prompt2scene/prompts/schemas.py @@ -29,6 +29,7 @@ "FILTER_EXTRA_INSTANCES_JSON_SCHEMA", "IMAGE_METRIC_SCALE_JSON_SCHEMA", "SCENE_INTAKE_JSON_SCHEMA", + "SPATIAL_LAYOUT_VERIFIER_JSON_SCHEMA", "SPATIAL_LAYOUT_JSON_SCHEMA", "TEXT_RELATIONS_JSON_SCHEMA", "UP_DOWN_FLIP_CHECK_JSON_SCHEMA", @@ -249,6 +250,25 @@ "required": ["anchor", "x_order", "y_order", "asset_states"], } +SPATIAL_LAYOUT_VERIFIER_JSON_SCHEMA: dict[str, Any] = { + "title": "ImageSpatialLayoutVerifierOutput", + "type": "object", + "additionalProperties": False, + "properties": { + "passed": { + "type": "boolean", + "description": "Whether the draft spatial layout is correct.", + }, + "reason": { + "type": "string", + "minLength": 1, + "description": "Concise verification reason.", + }, + "corrected_layout": SPATIAL_LAYOUT_JSON_SCHEMA, + }, + "required": ["passed", "reason", "corrected_layout"], +} + TEXT_RELATIONS_JSON_SCHEMA: dict[str, Any] = { "title": "TextRelationsOutput", diff --git a/embodichain/gen_sim/prompt2scene/workflows/gym_export.py b/embodichain/gen_sim/prompt2scene/workflows/gym_export.py index 2a68c5eb6..f1b575124 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/gym_export.py +++ b/embodichain/gen_sim/prompt2scene/workflows/gym_export.py @@ -145,6 +145,23 @@ def _glb_scale_to_sim(scale: Sequence[float]) -> list[float]: return [values[0], values[2], values[1]] +def _decompose_affine_matrix(matrix_value: Any) -> tuple[list[float], list[float], list[float]]: + matrix = np.asarray(matrix_value, dtype=np.float64) + if matrix.shape != (4, 4): + raise ValueError("Expected a 4x4 affine matrix.") + linear = matrix[:3, :3] + scale = np.linalg.norm(linear, axis=0) + rotation = np.eye(3, dtype=np.float64) + for index in range(3): + if scale[index] > 1.0e-12: + rotation[:, index] = linear[:, index] / scale[index] + return ( + matrix[:3, 3].tolist(), + _matrix_to_euler_xyz_deg(rotation.tolist()), + scale.tolist(), + ) + + def _glb_max_z(glb_path: Path) -> float: """Maximum height (Y in GLB, Z in simulation) of a mesh.""" import trimesh @@ -215,6 +232,248 @@ def _rotated_aabb_offsets( ) +def _sim_world_xy_aabb( + glb_path: Path, + rotation_matrix: list[list[float]] | None, + scale: float | Sequence[float], + init_pos: Sequence[float], +) -> dict[str, Any]: + """Project a transformed simready GLB onto the simulation XY plane.""" + import trimesh + + scene = trimesh.load(glb_path, force="scene") + if isinstance(scene, trimesh.Trimesh): + mesh = scene + else: + dumped = scene.dump(concatenate=True) + mesh = ( + dumped + if isinstance(dumped, trimesh.Trimesh) + else trimesh.util.concatenate( + [m for m in dumped if isinstance(m, trimesh.Trimesh)] + ) + ) + verts = np.asarray(mesh.vertices.copy(), dtype=np.float64) + if isinstance(scale, Sequence) and not isinstance(scale, (str, bytes)): + scale_array = np.asarray(list(scale), dtype=np.float64) + if scale_array.shape != (3,): + raise ValueError("scale must be a scalar or a 3-vector") + verts *= scale_array + else: + verts *= float(scale) + if rotation_matrix is not None: + rot = np.asarray(rotation_matrix, dtype=np.float64) + if rot.shape == (4, 4): + rot = rot[:3, :3] + verts = (rot @ verts.T).T + + init = np.asarray(list(init_pos), dtype=np.float64) + if init.shape != (3,): + raise ValueError("init_pos must have three components") + sim_xy = np.column_stack((verts[:, 0] + init[0], -verts[:, 2] + init[1])) + min_xy = sim_xy.min(axis=0) + max_xy = sim_xy.max(axis=0) + center_xy = 0.5 * (min_xy + max_xy) + size_xy = np.maximum(max_xy - min_xy, 0.0) + return { + "unit": "m", + "center_xy": center_xy.tolist(), + "aabb_xy": [min_xy.tolist(), max_xy.tolist()], + "size_xy": size_xy.tolist(), + } + + +def _support_region_2d(table_fit_manifest: dict[str, Any]) -> dict[str, Any]: + support = table_fit_manifest.get("final_support_quad_centered") or {} + corners = np.asarray(support.get("corners_xy", []), dtype=np.float64) + if corners.shape != (4, 2): + return {"unit": "m", "center_xy": [], "aabb_xy": [], "size_xy": [], "corners_xy": []} + min_xy = corners.min(axis=0) + max_xy = corners.max(axis=0) + center_xy = np.asarray( + support.get("center_xy") or (0.5 * (min_xy + max_xy)).tolist(), + dtype=np.float64, + ) + size_xy = np.asarray( + support.get("size_xy") or (max_xy - min_xy).tolist(), + dtype=np.float64, + ) + return { + "unit": "m", + "center_xy": center_xy.tolist(), + "aabb_xy": [min_xy.tolist(), max_xy.tolist()], + "size_xy": size_xy.tolist(), + "corners_xy": corners.tolist(), + } + + +def _write_scene_state( + *, + export_dir: Path, + config_path: Path, + table_desc: str, + table_support_region_2d: dict[str, Any], + object_states: list[dict[str, Any]], + source_snapshots: dict[str, str], +) -> Path: + scene_state_dir = export_dir / "scene_state" + scene_state_dir.mkdir(parents=True, exist_ok=True) + plot_path = scene_state_dir / "topdown_2d.png" + state_path = scene_state_dir / "result.json" + state = { + "version": 1, + "coordinate_frame": { + "unit": "m", + "plane": "simulation_xy", + "x_axis": "simulation +X", + "y_axis": "simulation +Y", + "note": "2D values are top-down projections onto the simulation XY plane.", + }, + "gym_config_path": str(config_path.relative_to(export_dir)), + "topdown_2d_plot_path": str(plot_path.relative_to(export_dir)), + "source_snapshots": source_snapshots, + "table": { + "id": "table", + "role": "background", + "description": table_desc, + "support_region_2d": table_support_region_2d, + }, + "objects": object_states, + } + state_path.write_text( + json.dumps(state, indent=2, ensure_ascii=False) + "\n", + encoding="utf-8", + ) + _render_scene_state_topdown( + support_region=table_support_region_2d, + objects=object_states, + output_path=plot_path, + ) + return state_path + + +def _copy_scene_source_snapshots( + *, + paths: PipelinePaths, + export_dir: Path, + scene_state_dir: Path, +) -> dict[str, str]: + scene_state_dir.mkdir(parents=True, exist_ok=True) + snapshots: dict[str, str] = {} + sources = { + "unified_scene": paths.unified_scene_result, + "unified_scene_gen": paths.step_result(UNIFIED_SCENE_GEN_STEP), + } + for name, source in sources.items(): + if not source.is_file(): + continue + destination = scene_state_dir / f"{name}.json" + shutil.copy2(source, destination) + snapshots[name] = str(destination.relative_to(export_dir)) + return snapshots + + +def _render_scene_state_topdown( + *, + support_region: dict[str, Any], + objects: list[dict[str, Any]], + output_path: Path, +) -> None: + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + from matplotlib.patches import Polygon, Rectangle + + output_path.parent.mkdir(parents=True, exist_ok=True) + fig, ax = plt.subplots(figsize=(9, 9)) + + data_points: list[np.ndarray] = [] + corners = np.asarray(support_region.get("corners_xy", []), dtype=np.float64) + if corners.shape == (4, 2): + ax.add_patch( + Polygon( + corners, + closed=True, + facecolor=(0.18, 0.62, 0.32, 0.14), + edgecolor=(0.05, 0.38, 0.16, 1.0), + linewidth=2.0, + label="table support region", + ) + ) + data_points.append(corners) + + for obj in objects: + footprint = obj.get("footprint_2d") or {} + aabb = np.asarray(footprint.get("aabb_xy", []), dtype=np.float64) + center = np.asarray(footprint.get("center_xy", []), dtype=np.float64) + if aabb.shape != (2, 2) or center.shape != (2,): + continue + size = np.maximum(aabb[1] - aabb[0], 0.0) + ax.add_patch( + Rectangle( + aabb[0], + size[0], + size[1], + facecolor=(0.25, 0.48, 0.95, 0.22), + edgecolor=(0.08, 0.20, 0.65, 1.0), + linewidth=1.5, + ) + ) + ax.plot(center[0], center[1], "o", color="#102a7a", markersize=4) + label = str(obj.get("id", "")).replace("interact_", "") + ax.text( + center[0], + center[1], + f"{label}\n({center[0]:.3f}, {center[1]:.3f})", + ha="center", + va="center", + fontsize=8, + color="black", + ) + data_points.append(aabb) + + if data_points: + all_points = np.vstack(data_points) + data_min = all_points.min(axis=0) + data_max = all_points.max(axis=0) + else: + data_min = np.array([-0.5, -0.5], dtype=np.float64) + data_max = np.array([0.5, 0.5], dtype=np.float64) + span = np.maximum(data_max - data_min, 1.0e-3) + padding = max(float(span.max()) * 0.18, 0.05) + x_limits = (float(data_min[0] - padding), float(data_max[0] + padding)) + y_limits = (float(data_min[1] - padding), float(data_max[1] + padding)) + + ax.axhline(0.0, color="#303030", linewidth=1.2, alpha=0.75) + ax.axvline(0.0, color="#303030", linewidth=1.2, alpha=0.75) + ax.annotate( + "+X", + xy=(x_limits[1], 0.0), + xytext=(x_limits[1] - 0.08 * (x_limits[1] - x_limits[0]), 0.02), + arrowprops={"arrowstyle": "->", "color": "#303030", "lw": 1.4}, + color="#303030", + ) + ax.annotate( + "+Y", + xy=(0.0, y_limits[1]), + xytext=(0.02, y_limits[1] - 0.08 * (y_limits[1] - y_limits[0])), + arrowprops={"arrowstyle": "->", "color": "#303030", "lw": 1.4}, + color="#303030", + ) + ax.set_xlim(*x_limits) + ax.set_ylim(*y_limits) + ax.set_aspect("equal", adjustable="box") + ax.set_xlabel("X (m)") + ax.set_ylabel("Y (m)") + ax.set_title("Prompt2Scene Gym Export Top-Down 2D State") + ax.grid(True, linestyle=":", linewidth=0.6, alpha=0.45) + ax.legend(loc="upper right") + fig.tight_layout() + fig.savefig(output_path, dpi=180, facecolor="white") + plt.close(fig) + + def _build_object_manifest( output_root: Path, step_result: dict[str, Any], @@ -336,62 +595,84 @@ def export_gym_config( ) ) + aligned_by_id: dict[str, dict[str, Any]] = {} + if paths.simready_to_aligned_manifest.is_file(): + for item in _read_json(paths.simready_to_aligned_manifest).get("items", []) or []: + if isinstance(item, dict) and item.get("id"): + aligned_by_id[str(item["id"])] = item + + object_manifest = _build_object_manifest( + output_root, step_result, table_fit_manifest, aligned_by_id + ) + table_info = step_result.get("table") or {} table_desc = str( table_info.get("complete_table_description") or table_info.get("description", "") ).strip() - object_desc_by_id = { - str(item.get("id", "")): str( - item.get("description") or item.get("name") or "" - ).strip() - for item in step_result.get("objects") or [] - if isinstance(item, dict) and item.get("id") - } mesh_assets_dir = export_dir / "mesh_assets" mesh_assets_dir.mkdir(parents=True, exist_ok=True) - table_fit_output = _resolve_path( - table_fit_manifest.get("table_output_path", ""), + table_simready = _resolve_path( + table_info.get("simready_geometry_path") + or table_info.get("mesh_path", ""), output_root, ) - if not table_fit_output.is_file(): - raise FileNotFoundError(f"Table-fit GLB not found: {table_fit_output}") + if not table_simready.is_file(): + raise FileNotFoundError(f"Table simready GLB not found: {table_simready}") table_dst = mesh_assets_dir / "table" / "table_0.glb" table_dst.parent.mkdir(parents=True, exist_ok=True) - shutil.copy2(table_fit_output, table_dst) + shutil.copy2(table_simready, table_dst) + + table_fit_transform = table_fit_manifest.get("table_fit_transform") + if table_fit_transform: + table_init_pos, table_init_rot, table_body_scale = _decompose_affine_matrix( + table_fit_transform + ) + else: + uniform_scale = 1.0 + ts = table_fit_manifest.get("table_xy_scale") + if isinstance(ts, dict): + uniform_scale = float(ts.get("uniform_scale", 1.0)) + table_init_pos = [0.0, 0.0, 0.0] + table_init_rot = [0.0, 0.0, 0.0] + table_body_scale = [uniform_scale, uniform_scale, 1.0] rigid_objects: list[dict[str, Any]] = [] + object_states: list[dict[str, Any]] = [] - fitted_objects = [ - item - for item in table_fit_manifest.get("objects", []) or [] - if isinstance(item, dict) and item.get("id") and item.get("path") - ] - total = len(fitted_objects) - for idx, item in enumerate(fitted_objects): - oid = str(item["id"]) + total = len(object_manifest) + for idx, (oid, om) in enumerate(object_manifest.items()): safe_name = oid.replace("interact_", "").strip("_") or "object" obj_dir = mesh_assets_dir / safe_name / oid obj_dir.mkdir(parents=True, exist_ok=True) object_dst = obj_dir / f"{oid}.glb" - object_fit_path = _resolve_path(str(item["path"]), output_root) - if not object_fit_path.is_file(): - raise FileNotFoundError(f"Table-fit object GLB not found: {object_fit_path}") - shutil.copy2(object_fit_path, object_dst) - - # Table-fit GLBs already have the relative layout baked into vertices. - # Preview/export should not rebuild placement from simready transforms. - init_pos = [0.0, 0.0, 0.0] - init_rot = [0.0, 0.0, 0.0] - body_scale = [1.0, 1.0, 1.0] - description = object_desc_by_id.get(oid, oid) + shutil.copy2(om["simready_path"], object_dst) + + sf = om["scale_factor"] + scale_glb = om.get("transform_scale") or [sf, sf, sf] + body_scale = _glb_scale_to_sim(scale_glb) + + init_rot: list[float] = [0.0, 0.0, 0.0] + if om["rotation_matrix"] is not None: + init_rot = _matrix_to_euler_xyz_deg( + _glb_rotation_to_sim(om["rotation_matrix"]) + ) + + ro = _rotated_aabb_offsets( + om["simready_path"], om["rotation_matrix"], scale_glb + ) + wbc = om["world_aabb_bottom_center"] + if wbc is not None: + init_pos = [wbc[0] - ro[0], wbc[1] - ro[1], wbc[2] - ro[2]] + else: + raise ValueError(f"Missing table-fit world_aabb_bottom_center for {oid}") rigid_objects.append( { "uid": oid, - "description": description, + "description": om["description"], "shape": { "shape_type": "Mesh", "fpath": str(object_dst.relative_to(export_dir)), @@ -405,9 +686,28 @@ def export_gym_config( "max_convex_hull_num": _DEFAULT_MAX_CONVEX_HULL_NUM, } ) + footprint_2d = _sim_world_xy_aabb( + om["simready_path"], + om["rotation_matrix"], + scale_glb, + init_pos, + ) + object_states.append( + { + "id": oid, + "name": safe_name, + "role": "interact", + "description": om["description"], + "init_pos": init_pos, + "init_rot": init_rot, + "body_scale": body_scale, + "footprint_2d": footprint_2d, + } + ) + wbc_flag = "wbc" if wbc is not None else "missing_wbc" print( - f" [{idx+1}/{total}] [{oid}] {description}" - f" pos={init_pos} rot={init_rot} scale={body_scale} src=table_fit_glb" + f" [{idx+1}/{total}] [{oid}] {om['description']}" + f" pos={init_pos} rot={init_rot} scale={body_scale} src={wbc_flag}" ) config = { @@ -428,10 +728,10 @@ def export_gym_config( "compute_uv": False, }, "attrs": dict(_DEFAULT_TABLE_ATTRS), - "body_scale": [1.0, 1.0, 1.0], + "body_scale": table_body_scale, "body_type": "kinematic", - "init_pos": [0.0, 0.0, 0.0], - "init_rot": [0.0, 0.0, 0.0], + "init_pos": table_init_pos, + "init_rot": table_init_rot, } ], "rigid_object": rigid_objects, @@ -442,5 +742,20 @@ def export_gym_config( json.dumps(config, indent=4, ensure_ascii=False) + "\n", encoding="utf-8", ) + scene_state_dir = export_dir / "scene_state" + source_snapshots = _copy_scene_source_snapshots( + paths=paths, + export_dir=export_dir, + scene_state_dir=scene_state_dir, + ) + scene_state_path = _write_scene_state( + export_dir=export_dir, + config_path=config_path, + table_desc=table_desc, + table_support_region_2d=_support_region_2d(table_fit_manifest), + object_states=object_states, + source_snapshots=source_snapshots, + ) + print(f" scene_state={scene_state_path.relative_to(export_dir)}") return config_path diff --git a/embodichain/gen_sim/prompt2scene/workflows/image_relations/nodes.py b/embodichain/gen_sim/prompt2scene/workflows/image_relations/nodes.py index fbaad0e50..2c36868c2 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/image_relations/nodes.py +++ b/embodichain/gen_sim/prompt2scene/workflows/image_relations/nodes.py @@ -20,7 +20,6 @@ from typing import Any from embodichain.gen_sim.prompt2scene.agent_tools.tools.image_segment_filter import ( - filter_group_segments_with_vlm, filter_segments_with_vlm, ) from embodichain.gen_sim.prompt2scene.workflows.image_relations.schema import ( @@ -47,6 +46,7 @@ asset_bbox_label, draw_labeled_bboxes, expand_asset_ids, + filter_group_segments_with_artifacts, merge_non_overlapping_segments, prompt_text, path_token, @@ -56,6 +56,7 @@ segments_from_response, select_largest_table_segment, table_segmentation_prompts, + verify_spatial_layout_output, write_table_candidate_debug_image, ) from embodichain.gen_sim.prompt2scene.prompts.builders import ( @@ -83,6 +84,7 @@ "segment_by_name_node", ] + def prepare_segmentation_input_node(state: ImageRelationsState) -> dict[str, object]: """Prepare scene-intake asset groups for class-level segmentation.""" request = state["request"] @@ -167,6 +169,11 @@ def retry_missing_by_candidates_node( group = dict(group) segments = group["segments"] expected_count = group["expected_count"] + confirmed_segments = [ + segment + for existing_group in state["segment_groups"] + for segment in existing_group.get("segments", []) + ] for candidate_name in group["class_candidate"][1:]: if len(segments) >= expected_count: break @@ -181,26 +188,21 @@ def retry_missing_by_candidates_node( source_prompt=prompt, ) stage_label = f"fallback_{path_token(prompt)}" - round_name_inner = artifact_writer.next_debug_round_name( - label=f"{stage_label}_{group['name']}" - ) - round_dir_inner = artifact_writer.debug_round_dir(round_name_inner) - new_segments = filter_group_segments_with_vlm( + new_segments = filter_group_segments_with_artifacts( llm=llm, image_path=image_path, - step_name=IMAGE_SEGMENTS_STEP, + artifact_writer=artifact_writer, group=group, segments=new_segments, stage=stage_label, - debug_round_name=round_name_inner, - debug_round_dir=round_dir_inner, - write_debug_json=artifact_writer.write_debug_round_json, + confirmed_segments=confirmed_segments, ) segments = merge_non_overlapping_segments( existing=segments, incoming=new_segments, limit=expected_count, ) + confirmed_segments = confirmed_segments + new_segments if len(segments) < expected_count: description_prompt = str(group.get("description") or "").strip() if description_prompt and description_prompt not in group["tried_prompts"]: @@ -217,19 +219,21 @@ def retry_missing_by_candidates_node( response=response, source_prompt=description_prompt, ) - new_segments = filter_group_segments_with_vlm( + new_segments = filter_group_segments_with_artifacts( llm=llm, image_path=image_path, artifact_writer=artifact_writer, group=group, segments=new_segments, stage="fallback_description", + confirmed_segments=confirmed_segments, ) segments = merge_non_overlapping_segments( existing=segments, incoming=new_segments, limit=expected_count, ) + confirmed_segments = confirmed_segments + new_segments group["segments"] = segments segment_groups.append(group) return {"segment_groups": segment_groups} @@ -453,11 +457,22 @@ def call_vlm_spatial_layout_node( ) + verifier_output = verify_spatial_layout_output( + llm=llm, + bbox_name_image_path=Path(image_relations.bbox_name_image_path), + asset_ids=asset_ids, + raw_model_output=raw_model_output, + attempt_count=attempt_count, + artifact_writer=artifact_writer, + ) + verified_model_output = verifier_output["corrected_layout"] updated_image_relations = apply_spatial_layout_output( image_relations=image_relations, - raw_model_output=raw_model_output, + raw_model_output=verified_model_output, ) - artifact_writer.write_step_result(updated_image_relations.to_spatial_manifest()) + spatial_manifest = updated_image_relations.to_spatial_manifest() + spatial_manifest["spatial_layout_verifier"] = verifier_output + artifact_writer.write_step_result(spatial_manifest) except Exception as exc: if is_model_output_error(exc) or isinstance(exc, ValueError): error = format_attempt_error( diff --git a/embodichain/gen_sim/prompt2scene/workflows/image_relations/utils.py b/embodichain/gen_sim/prompt2scene/workflows/image_relations/utils.py index 5a7070832..0644cea52 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/image_relations/utils.py +++ b/embodichain/gen_sim/prompt2scene/workflows/image_relations/utils.py @@ -16,6 +16,7 @@ from __future__ import annotations +import json from pathlib import Path from typing import Any @@ -31,8 +32,12 @@ is_usable_segmentation_candidate, sort_segments_by_bbox, ) +from embodichain.gen_sim.prompt2scene.agent_tools.tools.image_segment_filter import ( + filter_group_segments_with_vlm, +) from embodichain.gen_sim.prompt2scene.prompts.schemas import ( SPATIAL_LAYOUT_JSON_SCHEMA, + SPATIAL_LAYOUT_VERIFIER_JSON_SCHEMA, ) from embodichain.gen_sim.prompt2scene.workflows.image_relations.schema import ( ImageAnchor, @@ -53,6 +58,7 @@ ) from embodichain.gen_sim.prompt2scene.prompts.builders import ( build_spatial_layout_messages, + build_spatial_layout_verifier_messages, ) from embodichain.gen_sim.prompt2scene.llms.llm_output import ( call_structured_json_model_step, @@ -66,6 +72,7 @@ "asset_bbox_label", "draw_labeled_bboxes", "expand_asset_ids", + "filter_group_segments_with_artifacts", "merge_non_overlapping_segments", "parse_anchor", "parse_asset_states", @@ -79,6 +86,7 @@ "select_largest_table_segment", "sort_segments_by_bbox", "table_segmentation_prompts", + "verify_spatial_layout_output", "write_table_candidate_debug_image", ] @@ -213,6 +221,54 @@ def apply_spatial_layout_output( ) +def verify_spatial_layout_output( + *, + llm: Any, + bbox_name_image_path: Path, + asset_ids: list[str], + raw_model_output: dict[str, Any], + attempt_count: int, + artifact_writer: WorkflowArtifactWriter, +) -> dict[str, Any]: + """Verify and optionally rewrite spatial layout VLM output.""" + messages = build_spatial_layout_verifier_messages( + bbox_name_image_path=bbox_name_image_path, + asset_ids=asset_ids, + draft_spatial_layout_json=json.dumps( + raw_model_output, + ensure_ascii=False, + indent=2, + ), + ) + log_api_request_start( + step=IMAGE_SPATIAL_RELATIONS_STEP, + request="spatial_layout_verify", + attempt=attempt_count, + ) + round_name = artifact_writer.next_debug_round_name("spatial_layout_verify") + verifier_output = call_structured_json_model_step( + llm=llm, + schema=SPATIAL_LAYOUT_VERIFIER_JSON_SCHEMA, + messages=messages, + context="Image spatial layout verifier", + attempt_count=attempt_count, + raw_output_writer=lambda payload: artifact_writer.write_debug_round_json( + round_name=round_name, + filename="raw_model_output.json", + payload=payload, + ), + ) + artifact_writer.write_debug_round_json( + round_name=round_name, + filename="verifier_result.json", + payload=verifier_output, + ) + corrected = verifier_output.get("corrected_layout") + if not isinstance(corrected, dict): + raise ValueError("spatial_layout_verifier.corrected_layout must be an object.") + return verifier_output + + def parse_anchor(raw_anchor: Any, *, asset_id_set: set[str]) -> ImageAnchor: """Parse and validate the anchor entry.""" if not isinstance(raw_anchor, dict): @@ -336,6 +392,34 @@ def write_table_candidate_debug_image( group["debug_images"] = debug_images +def filter_group_segments_with_artifacts( + *, + llm: Any, + image_path: Path, + artifact_writer: WorkflowArtifactWriter, + group: dict[str, Any], + segments: list[dict[str, Any]], + stage: str, + confirmed_segments: list[dict[str, Any]] | None = None, +) -> list[dict[str, Any]]: + """Filter one group while keeping workflow artifact handling out of nodes.""" + round_name = artifact_writer.next_debug_round_name( + label=f"{stage}_{group['name']}" + ) + return filter_group_segments_with_vlm( + llm=llm, + image_path=image_path, + step_name=IMAGE_SEGMENTS_STEP, + group=group, + segments=segments, + stage=stage, + debug_round_name=round_name, + debug_round_dir=artifact_writer.debug_round_dir(round_name), + write_debug_json=artifact_writer.write_debug_round_json, + confirmed_segments=confirmed_segments, + ) + + def select_largest_table_segment( segments: list[dict[str, Any]], ) -> dict[str, Any] | None: diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/utils.py b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/utils.py index da084f559..3cd5405ff 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/scene_intake/utils.py +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_intake/utils.py @@ -206,20 +206,37 @@ def _parse_class_candidate( raise ValueError( f"Scene intake asset {asset_index}.class_candidate must be a list." ) - class_candidate = [normalize_asset_name(str(item)) for item in raw_class_candidate] - if len(class_candidate) != 5: + class_candidate = [ + normalize_asset_name(str(item)) + for item in raw_class_candidate + if normalize_asset_name(str(item)) + ] + expected_name = normalize_asset_name(raw_name) + normalized_candidates = [expected_name] + for candidate in class_candidate: + if candidate != expected_name and candidate not in normalized_candidates: + normalized_candidates.append(candidate) + generic_fallbacks = [ + "object", + "item", + "container", + "tableware", + "household_object", + ] + for fallback in generic_fallbacks: + if len(normalized_candidates) >= 5: + break + if fallback != expected_name and fallback not in normalized_candidates: + normalized_candidates.append(fallback) + if len(normalized_candidates) != 5: raise ValueError( f"Scene intake asset {asset_index}.class_candidate must contain exactly five entries." ) - if any(not candidate for candidate in class_candidate): + if any(not candidate for candidate in normalized_candidates): raise ValueError( f"Scene intake asset {asset_index}.class_candidate has empty entries." ) - if class_candidate[0] != normalize_asset_name(raw_name): - raise ValueError( - f"Scene intake asset {asset_index}.class_candidate[0] must equal name." - ) - return class_candidate + return normalized_candidates def _parse_count(raw_count: Any, *, asset_index: int) -> int: From 0da9dddce36ce9b383bdbe703e6b91316e3538ef Mon Sep 17 00:00:00 2001 From: Muzi Wong <178915912+MuziWong@users.noreply.github.com> Date: Wed, 1 Jul 2026 22:10:15 +0800 Subject: [PATCH 7/7] Finished scene editing; --- .../managers/layout_manager/utils.py | 13 + embodichain/gen_sim/prompt2scene/cli/start.py | 22 +- .../gen_sim/prompt2scene/pipeline/runner.py | 40 +- .../gen_sim/prompt2scene/prompts/builders.py | 33 + .../prompt2scene/prompts/data/scene_edit.yaml | 129 ++ .../gen_sim/prompt2scene/prompts/schemas.py | 189 +- .../prompt2scene/workflows/__init__.py | 2 + .../prompt2scene/workflows/artifact_writer.py | 2 + .../gen_sim/prompt2scene/workflows/paths.py | 6 + .../gen_sim/prompt2scene/workflows/request.py | 42 +- .../workflows/scene_edit/__init__.py | 23 + .../workflows/scene_edit/graph.py | 108 ++ .../workflows/scene_edit/nodes.py | 326 ++++ .../workflows/scene_edit/schema.py | 52 + .../workflows/scene_edit/utils.py | 1554 +++++++++++++++++ 15 files changed, 2531 insertions(+), 10 deletions(-) create mode 100644 embodichain/gen_sim/prompt2scene/prompts/data/scene_edit.yaml create mode 100644 embodichain/gen_sim/prompt2scene/workflows/scene_edit/__init__.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/scene_edit/graph.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/scene_edit/nodes.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/scene_edit/schema.py create mode 100644 embodichain/gen_sim/prompt2scene/workflows/scene_edit/utils.py diff --git a/embodichain/gen_sim/prompt2scene/agent_tools/managers/layout_manager/utils.py b/embodichain/gen_sim/prompt2scene/agent_tools/managers/layout_manager/utils.py index a4b2dde39..ec4879f14 100644 --- a/embodichain/gen_sim/prompt2scene/agent_tools/managers/layout_manager/utils.py +++ b/embodichain/gen_sim/prompt2scene/agent_tools/managers/layout_manager/utils.py @@ -10,6 +10,7 @@ from typing import Any import numpy as np +from scipy.optimize import minimize from embodichain.gen_sim.prompt2scene.agent_tools.managers.simulation_manager import ( SimulationManager, @@ -33,6 +34,18 @@ "_xy_union_bounds", ] +_WEIGHTS: dict[str, float] = { + "seed": 1.0, + "overlap": 200.0, + "grid": 3.0, +} + +_SLSQP_OPTIONS: dict[str, Any] = { + "maxiter": 300, + "ftol": 1.0e-6, + "disp": False, +} + def _object_scenes_xy_aabb_manifest( *, object_scenes: list[tuple[str, Any]], diff --git a/embodichain/gen_sim/prompt2scene/cli/start.py b/embodichain/gen_sim/prompt2scene/cli/start.py index fdc3a27b5..6468628a4 100644 --- a/embodichain/gen_sim/prompt2scene/cli/start.py +++ b/embodichain/gen_sim/prompt2scene/cli/start.py @@ -29,6 +29,7 @@ def cli_prompt2scene( image_path: str | None, text: str | None, + prompt: str | None, output_root: str, llm_config_path: str | None = None, ) -> None: @@ -37,12 +38,14 @@ def cli_prompt2scene( Args: image_path: Path to an input image, if image mode is used. text: Text prompt, if text mode is used. + prompt: Optional edit prompt. output_root: Directory where prompt2scene outputs are written. llm_config_path: Optional path to the LLM config JSON file. """ request = Prompt2SceneInput.from_cli_args( image_path=Path(image_path) if image_path is not None else None, text=text, + prompt=prompt, output_root=Path(output_root), ) llm_cfg = load_llm_config( @@ -57,7 +60,7 @@ def main() -> None: description="embodichain.gen_sim.prompt2scene Prompt-to-Scene Pipeline" ) - input_group = parser.add_mutually_exclusive_group(required=True) + input_group = parser.add_mutually_exclusive_group(required=False) input_group.add_argument( "--image", type=str, @@ -68,6 +71,15 @@ def main() -> None: type=str, help="Text prompt describing the target scene", ) + parser.add_argument( + "--prompt", + type=str, + default=None, + help=( + "Optional edit instruction. Use with --image/--text to edit after " + "generation, or with only --output_root to edit an existing scene." + ), + ) parser.add_argument( "--output_root", type=str, @@ -83,7 +95,13 @@ def main() -> None: args = parser.parse_args() - cli_prompt2scene(args.image, args.text, args.output_root, args.llm_config) + cli_prompt2scene( + args.image, + args.text, + args.prompt, + args.output_root, + args.llm_config, + ) if __name__ == "__main__": diff --git a/embodichain/gen_sim/prompt2scene/pipeline/runner.py b/embodichain/gen_sim/prompt2scene/pipeline/runner.py index 90d788c3e..1b6f12700 100644 --- a/embodichain/gen_sim/prompt2scene/pipeline/runner.py +++ b/embodichain/gen_sim/prompt2scene/pipeline/runner.py @@ -27,6 +27,7 @@ from embodichain.gen_sim.prompt2scene.workflows.paths import ( IMAGE_SEGMENTS_STEP, IMAGE_SPATIAL_RELATIONS_STEP, + SCENE_EDIT_STEP, SCENE_INTAKE_STEP, TEXT_RELATIONS_STEP, UNIFIED_SCENE_STEP, @@ -44,6 +45,10 @@ from embodichain.gen_sim.prompt2scene.workflows.gym_export import ( export_gym_config, ) +from embodichain.gen_sim.prompt2scene.workflows.scene_edit import run_scene_edit +from embodichain.gen_sim.prompt2scene.workflows.scene_edit.schema import ( + SceneEditRequest, +) from embodichain.gen_sim.prompt2scene.utils.io import write_json from embodichain.gen_sim.prompt2scene.utils import log from embodichain.gen_sim.prompt2scene.workflows.image_relations import ( @@ -58,6 +63,7 @@ "IMAGE_SEGMENTS_DIRNAME", "IMAGE_SPATIAL_RELATIONS_DIRNAME", "INPUT_MANIFEST_FILENAME", + "SCENE_EDIT_DIRNAME", "SCENE_INTAKE_DIRNAME", "STEP_RESULT_FILENAME", "TEXT_RELATIONS_DIRNAME", @@ -68,6 +74,7 @@ INPUT_MANIFEST_FILENAME = "input_manifest.json" SCENE_INTAKE_DIRNAME = SCENE_INTAKE_STEP +SCENE_EDIT_DIRNAME = SCENE_EDIT_STEP IMAGE_SEGMENTS_DIRNAME = IMAGE_SEGMENTS_STEP IMAGE_SPATIAL_RELATIONS_DIRNAME = IMAGE_SPATIAL_RELATIONS_STEP TEXT_RELATIONS_DIRNAME = TEXT_RELATIONS_STEP @@ -86,6 +93,8 @@ class Prompt2SceneRunResult: image_spatial_relations_path: Path to serialized image spatial relations. text_relations_path: Path to serialized text spatial relations. unified_scene_path: Path to serialized unified scene output. + gym_config_path: Path to the exported gym config. + scene_edit_path: Path to serialized scene edit output. """ output_root: Path @@ -96,6 +105,7 @@ class Prompt2SceneRunResult: text_relations_path: Path | None = None unified_scene_path: Path | None = None gym_config_path: Path | None = None + scene_edit_path: Path | None = None def run_prompt2scene( @@ -132,7 +142,21 @@ def run_prompt2scene( text_relations_path = None unified_scene_path = None gym_config_path = None - if llm_cfg is not None: + scene_edit_path = None + if request.input_kind == InputKind.EDIT: + log.log_info("step start scene_edit") + run_scene_edit( + SceneEditRequest( + output_root=request.output_root, + prompt=request.prompt or "", + ), + llm_cfg=llm_cfg, + ) + scene_edit_path = paths.step_result(SCENE_EDIT_STEP) + log.log_info( + f"step end scene_edit status=pending_implementation output={scene_edit_path}" + ) + elif llm_cfg is not None: log.log_info("step start scene_intake") scene_intake = run_scene_intake(request, llm_cfg=llm_cfg) scene_intake_path = write_step_result( @@ -221,6 +245,19 @@ def run_prompt2scene( log.log_info("step start gym_export") gym_config_path = export_gym_config(request.output_root) log.log_info(f"step end gym_export status=ok output={gym_config_path}") + if request.prompt: + log.log_info("step start scene_edit") + run_scene_edit( + SceneEditRequest( + output_root=request.output_root, + prompt=request.prompt, + ), + llm_cfg=llm_cfg, + ) + scene_edit_path = paths.step_result(SCENE_EDIT_STEP) + log.log_info( + f"step end scene_edit status=pending_implementation output={scene_edit_path}" + ) log.log_info(f"run end output_root={request.output_root}") @@ -233,4 +270,5 @@ def run_prompt2scene( text_relations_path=text_relations_path, unified_scene_path=unified_scene_path, gym_config_path=gym_config_path, + scene_edit_path=scene_edit_path, ) diff --git a/embodichain/gen_sim/prompt2scene/prompts/builders.py b/embodichain/gen_sim/prompt2scene/prompts/builders.py index 99018e329..fcdd95d1b 100644 --- a/embodichain/gen_sim/prompt2scene/prompts/builders.py +++ b/embodichain/gen_sim/prompt2scene/prompts/builders.py @@ -28,6 +28,7 @@ "build_image_metric_scale_messages", "build_scene_intake_messages", "build_scene_intake_verifier_messages", + "build_scene_edit_intent_messages", "build_spatial_layout_messages", "build_spatial_layout_verifier_messages", "build_text_metric_scale_messages", @@ -37,6 +38,7 @@ SCENE_INTAKE_PROMPT = "scene_intake.yaml" +SCENE_EDIT_PROMPT = "scene_edit.yaml" IMAGE_RELATIONS_PROMPT = "image_relations.yaml" TEXT_RELATIONS_PROMPT = "text_relations.yaml" UNIFIED_SCENE_GEN_PROMPT = "unified_scene_gen.yaml" @@ -303,6 +305,37 @@ def build_spatial_layout_verifier_messages( ] +def build_scene_edit_intent_messages( + *, + prompt: str, + scene_objects: list[dict[str, Any]], + current_relations: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Build messages for editing an existing scene from a user prompt.""" + return [ + { + "role": "system", + "content": render_prompt( + SCENE_EDIT_PROMPT, + prompt_key="intent_system", + ), + }, + { + "role": "user", + "content": render_prompt( + SCENE_EDIT_PROMPT, + { + "prompt": prompt, + "scene_objects_json": json.dumps( + scene_objects, ensure_ascii=False, indent=2 + ), + }, + prompt_key="intent_user", + ), + }, + ] + + def build_text_relation_messages( diff --git a/embodichain/gen_sim/prompt2scene/prompts/data/scene_edit.yaml b/embodichain/gen_sim/prompt2scene/prompts/data/scene_edit.yaml new file mode 100644 index 000000000..1b1a0a025 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/prompts/data/scene_edit.yaml @@ -0,0 +1,129 @@ +intent_system: | + You are a scene-edit intent planner for a tabletop simulation scene. + + You receive: + - A user edit prompt. + - Existing scene objects. Each object has only id, name, and description. + + Your job is to produce a structured edit intent. Do not generate geometry. + + Rules: + - Existing objects must be referenced only by their exact object id. + - deleted_object_ids may contain only ids from the existing scene objects. + - A replacement is represented as deleting the target object and adding one generated object. + - Every added or replacement object must appear in generated_objects with a temp_id, canonical snake_case name, and standalone detailed description suitable for text-to-geometry simready generation. + - For delete operations, identify only the existing object ids to delete. + - For replacement operations, set type=replace, set target_object_id to the existing object being replaced, create exactly one generated object, and set placement.type=preserve_target unless the user gives an explicit new placement. + - For add operations, infer placement from the user prompt when possible. + - Layout may be expressed either as a relation to an existing object or as a 9-grid location. + - A 9-grid location means the object's center or main tabletop footprint should be placed in one of the nine regions of the table support region split in XY: center, front, back, left_center, right_center, left_front, right_front, left_back, right_back. It is not a grid on the object itself. + - For relative layout, use placement.type=relative_to_object, exact reference_object_id, and one of left_of, right_of, front_of, or back_of. + - For 9-grid layout, use placement.type=grid and one of: center, front, back, left_center, right_center, left_front, right_front, left_back, right_back. + - Do not compute the final relation closure yourself. The program will remove deleted ids, inherit replacement layout, normalize right/back, preserve unaffected relations, add new layout, and compute transitive closure. + - Do not invent a target object when the prompt is ambiguous. Put the ambiguity in unresolved. + - If an operation cannot be grounded to an existing target or reference object, keep it unresolved instead of guessing. + - The top-level key reason is required. + - Every generated object must include source_operation. + - Every operation must include reason and confidence. + - Every unresolved item must include reason. + - Output JSON only. + + Example 1: + Existing scene objects: + [ + { + "id": "interact_plastic_water_bottle_0", + "name": "plastic_water_bottle", + "description": "A clear plastic water bottle with a white cap and a green label." + }, + { + "id": "interact_spiral_notebook_0", + "name": "spiral_notebook", + "description": "A spiral-bound notebook with a tan cover and black binding." + } + ] + User prompt: + replace the water bottle with a big red apple + Output: + { + "deleted_object_ids": ["interact_plastic_water_bottle_0"], + "generated_objects": [ + { + "temp_id": "new_red_apple_0", + "name": "red_apple", + "description": "A large red apple with smooth glossy skin, a round body, and a short brown stem.", + "source_operation": "replace" + } + ], + "operations": [ + { + "type": "replace", + "target_object_id": "interact_plastic_water_bottle_0", + "new_object_temp_id": "new_red_apple_0", + "placement": { + "type": "preserve_target", + "reference_object_id": "", + "relation": "", + "grid": "" + }, + "reason": "The user explicitly replaces the water bottle with a new apple object.", + "confidence": 0.98 + } + ], + "unresolved": [], + "reason": "Replace the water bottle and keep its original placement context." + } + + Example 2: + Existing scene objects: + [ + { + "id": "interact_spiral_notebook_0", + "name": "spiral_notebook", + "description": "A spiral-bound notebook with a tan cover and black binding." + } + ] + User prompt: + add a blue mug to the left front of the table + Output: + { + "deleted_object_ids": [], + "generated_objects": [ + { + "temp_id": "new_blue_mug_0", + "name": "blue_mug", + "description": "A ceramic blue mug with a glossy finish, a rounded body, and a side handle.", + "source_operation": "add" + } + ], + "operations": [ + { + "type": "add", + "target_object_id": "", + "new_object_temp_id": "new_blue_mug_0", + "placement": { + "type": "grid", + "reference_object_id": "", + "relation": "", + "grid": "left_front" + }, + "reason": "The user asks to add a new mug in a specific table 9-grid region.", + "confidence": 0.96 + } + ], + "unresolved": [], + "reason": "Add a new mug at the left_front region of the table support area." + } + +intent_user: | + User edit prompt: + ${prompt} + + Existing scene objects: + ${scene_objects_json} + + Produce the scene edit intent. Remember: + - deleted_object_ids includes both explicitly deleted objects and replaced old objects. + - generated_objects includes every new object needed by add or replace operations. + - operations must contain enough placement information for the program to update relations and 9-grid assignments. + - The required keys must all be present, including top-level reason, generated_objects[].source_operation, operations[].reason, operations[].confidence, and unresolved[].reason when unresolved is non-empty. diff --git a/embodichain/gen_sim/prompt2scene/prompts/schemas.py b/embodichain/gen_sim/prompt2scene/prompts/schemas.py index 1c9893b93..4cebaacba 100644 --- a/embodichain/gen_sim/prompt2scene/prompts/schemas.py +++ b/embodichain/gen_sim/prompt2scene/prompts/schemas.py @@ -29,6 +29,7 @@ "FILTER_EXTRA_INSTANCES_JSON_SCHEMA", "IMAGE_METRIC_SCALE_JSON_SCHEMA", "SCENE_INTAKE_JSON_SCHEMA", + "SCENE_EDIT_INTENT_JSON_SCHEMA", "SPATIAL_LAYOUT_VERIFIER_JSON_SCHEMA", "SPATIAL_LAYOUT_JSON_SCHEMA", "TEXT_RELATIONS_JSON_SCHEMA", @@ -237,7 +238,11 @@ "properties": { "asset_id": {"type": "string", "minLength": 1}, "is_arbitrary_layout": {"type": "boolean"}, - "reason": {"type": "string", "minLength": 1}, + "reason": { + "type": "string", + "minLength": 1, + "description": "Optional short explanation for debugging.", + }, }, "required": [ "asset_id", @@ -316,7 +321,11 @@ "properties": { "asset": {"type": "string", "minLength": 1}, "is_arbitrary_layout": {"type": "boolean"}, - "reason": {"type": "string", "minLength": 1}, + "reason": { + "type": "string", + "minLength": 1, + "description": "Optional explanation for this unresolved item.", + }, }, "required": ["asset", "is_arbitrary_layout", "reason"], }, @@ -372,3 +381,179 @@ }, "required": ["object_scales"], } + +SCENE_EDIT_INTENT_JSON_SCHEMA: dict[str, Any] = { + "title": "SceneEditIntentOutput", + "type": "object", + "additionalProperties": False, + "properties": { + "deleted_object_ids": { + "type": "array", + "description": ( + "Existing scene object ids that should be removed. This includes " + "objects removed by delete operations and objects replaced by new " + "generated objects." + ), + "items": {"type": "string", "minLength": 1}, + }, + "generated_objects": { + "type": "array", + "description": ( + "New objects that must be generated by the text-to-geometry " + "simready pipeline for add or replace operations." + ), + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "temp_id": { + "type": "string", + "minLength": 1, + "description": ( + "Temporary id used by this edit plan, such as " + "new_red_mug_0. It must not collide with existing ids." + ), + }, + "name": { + "type": "string", + "minLength": 1, + "description": ( + "Canonical English snake_case object name for " + "text-to-geometry." + ), + }, + "description": { + "type": "string", + "minLength": 20, + "maxLength": 220, + "description": ( + "Standalone appearance description used for " + "text-to-geometry simready generation." + ), + }, + "source_operation": { + "type": "string", + "enum": ["add", "replace"], + }, + }, + "required": [ + "temp_id", + "name", + "description", + "source_operation", + ], + }, + }, + "operations": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "type": { + "type": "string", + "enum": ["delete", "replace", "add"], + }, + "target_object_id": { + "type": "string", + "description": ( + "Existing object id for delete/replace, or empty " + "string for pure add." + ), + }, + "new_object_temp_id": { + "type": "string", + "description": ( + "Generated object temp_id for add/replace, or empty " + "string for delete." + ), + }, + "placement": { + "type": "object", + "additionalProperties": False, + "properties": { + "type": { + "type": "string", + "enum": [ + "preserve_target", + "random", + "relative_to_object", + "grid", + ], + }, + "reference_object_id": { + "type": "string", + "description": ( + "Existing object id used as a spatial " + "reference, or empty string if unused." + ), + }, + "relation": { + "type": "string", + "enum": [ + "", + "left_of", + "right_of", + "front_of", + "back_of", + ], + }, + "grid": { + "type": "string", + "enum": [""] + GRID_VALUE_LIST, + }, + }, + "required": [ + "type", + "reference_object_id", + "relation", + "grid", + ], + }, + "reason": {"type": "string", "minLength": 1}, + "confidence": { + "type": "number", + "minimum": 0.0, + "maximum": 1.0, + }, + }, + "required": [ + "type", + "target_object_id", + "new_object_temp_id", + "placement", + "reason", + "confidence", + ], + }, + }, + "unresolved": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": False, + "properties": { + "query": {"type": "string", "minLength": 1}, + "reason": {"type": "string", "minLength": 1}, + "candidate_object_ids": { + "type": "array", + "items": {"type": "string", "minLength": 1}, + }, + }, + "required": ["query", "reason", "candidate_object_ids"], + }, + }, + "reason": { + "type": "string", + "minLength": 1, + "description": "Brief overall explanation of the edit interpretation.", + }, + }, + "required": [ + "deleted_object_ids", + "generated_objects", + "operations", + "unresolved", + "reason", + ], +} diff --git a/embodichain/gen_sim/prompt2scene/workflows/__init__.py b/embodichain/gen_sim/prompt2scene/workflows/__init__.py index 393b0022b..63a6c449b 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/__init__.py +++ b/embodichain/gen_sim/prompt2scene/workflows/__init__.py @@ -21,6 +21,7 @@ IMAGE_SEGMENTS_STEP, IMAGE_SPATIAL_RELATIONS_STEP, RAW_MODEL_OUTPUT_FILENAME, + SCENE_EDIT_STEP, SCENE_INTAKE_STEP, STEP_RESULT_FILENAME, TEXT_RELATIONS_STEP, @@ -33,6 +34,7 @@ "IMAGE_SEGMENTS_STEP", "IMAGE_SPATIAL_RELATIONS_STEP", "RAW_MODEL_OUTPUT_FILENAME", + "SCENE_EDIT_STEP", "SCENE_INTAKE_STEP", "STEP_RESULT_FILENAME", "TEXT_RELATIONS_STEP", diff --git a/embodichain/gen_sim/prompt2scene/workflows/artifact_writer.py b/embodichain/gen_sim/prompt2scene/workflows/artifact_writer.py index c535a9701..7660dfb09 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/artifact_writer.py +++ b/embodichain/gen_sim/prompt2scene/workflows/artifact_writer.py @@ -25,6 +25,7 @@ IMAGE_SEGMENTS_STEP, IMAGE_SPATIAL_RELATIONS_STEP, RAW_MODEL_OUTPUT_FILENAME, + SCENE_EDIT_STEP, SCENE_INTAKE_STEP, STEP_RESULT_FILENAME, TEXT_RELATIONS_STEP, @@ -43,6 +44,7 @@ "IMAGE_SEGMENTS_STEP", "IMAGE_SPATIAL_RELATIONS_STEP", "RAW_MODEL_OUTPUT_FILENAME", + "SCENE_EDIT_STEP", "SCENE_INTAKE_STEP", "STEP_RESULT_FILENAME", "TEXT_RELATIONS_STEP", diff --git a/embodichain/gen_sim/prompt2scene/workflows/paths.py b/embodichain/gen_sim/prompt2scene/workflows/paths.py index 21243fa62..681586297 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/paths.py +++ b/embodichain/gen_sim/prompt2scene/workflows/paths.py @@ -27,6 +27,7 @@ "IMAGE_SPATIAL_RELATIONS_STEP", "RAW_MODEL_OUTPUT_FILENAME", "SCENE_INTAKE_STEP", + "SCENE_EDIT_STEP", "STEP_RESULT_FILENAME", "TEXT_RELATIONS_STEP", "UNIFIED_SCENE_GEN_STEP", @@ -46,6 +47,7 @@ RAW_MODEL_OUTPUT_FILENAME = "raw_model_output.json" SCENE_INTAKE_STEP = "scene_intake" +SCENE_EDIT_STEP = "scene_edit" IMAGE_SEGMENTS_STEP = "image_segments" IMAGE_SPATIAL_RELATIONS_STEP = "image_spatial_relations" TEXT_RELATIONS_STEP = "text_relations" @@ -132,6 +134,10 @@ def __post_init__(self) -> None: def scene_intake_dir(self) -> Path: return self.output_root / SCENE_INTAKE_STEP + @property + def scene_edit_dir(self) -> Path: + return self.output_root / SCENE_EDIT_STEP + @property def image_segments_dir(self) -> Path: return self.output_root / IMAGE_SEGMENTS_STEP diff --git a/embodichain/gen_sim/prompt2scene/workflows/request.py b/embodichain/gen_sim/prompt2scene/workflows/request.py index 8cd01c30f..de4e91e00 100644 --- a/embodichain/gen_sim/prompt2scene/workflows/request.py +++ b/embodichain/gen_sim/prompt2scene/workflows/request.py @@ -30,6 +30,7 @@ class InputKind(str, Enum): IMAGE = "image" TEXT = "text" + EDIT = "edit" @dataclass(frozen=True) @@ -40,6 +41,7 @@ class Prompt2SceneInput: output_root: Path image_path: Path | None = None text: str | None = None + prompt: str | None = None @classmethod def from_cli_args( @@ -47,6 +49,7 @@ def from_cli_args( *, image_path: Path | None, text: str | None, + prompt: str | None, output_root: Path, ) -> "Prompt2SceneInput": """Create a prompt2scene input from CLI arguments. @@ -54,6 +57,7 @@ def from_cli_args( Args: image_path: Input image path, if image mode is selected. text: Text prompt, if text mode is selected. + prompt: Optional edit prompt. output_root: Directory where prompt2scene outputs are written. Returns: @@ -64,6 +68,12 @@ def from_cli_args( ValueError: If the image path is invalid or text input is empty. """ output_root = output_root.expanduser().resolve() + prompt_text = prompt.strip() if prompt is not None else None + if prompt_text == "": + prompt_text = None + + if image_path is not None and text is not None and text.strip(): + raise ValueError("Image and text inputs cannot be used at the same time.") if image_path is not None: image_path = image_path.expanduser().resolve() @@ -72,15 +82,21 @@ def from_cli_args( input_kind=InputKind.IMAGE, image_path=image_path, output_root=output_root, + prompt=prompt_text, ) - if text is None or not text.strip(): - raise ValueError("Text input must be non-empty.") + if text is not None and text.strip(): + return cls( + input_kind=InputKind.TEXT, + text=text.strip(), + output_root=output_root, + prompt=prompt_text, + ) return cls( - input_kind=InputKind.TEXT, - text=text.strip(), + input_kind=InputKind.EDIT, output_root=output_root, + prompt=cls._validate_edit_only_prompt(prompt_text, output_root), ) def to_manifest(self) -> dict[str, str]: @@ -92,11 +108,27 @@ def to_manifest(self) -> dict[str, str]: if self.input_kind == InputKind.IMAGE: image_path = self.image_path manifest["image_path"] = str(image_path) - else: + elif self.input_kind == InputKind.TEXT: text = self.text manifest["text"] = "" if text is None else text + if self.prompt is not None: + manifest["prompt"] = self.prompt return manifest + @staticmethod + def _validate_edit_only_prompt(prompt: str | None, output_root: Path) -> str: + if prompt is None: + raise ValueError( + "Provide --image, --text, or --prompt with an existing output_root." + ) + scene_state = output_root / "gym_export" / "scene_state" / "result.json" + if not scene_state.is_file(): + raise FileNotFoundError( + "Edit-only mode requires an existing scene state: " + f"{scene_state}" + ) + return prompt + @staticmethod def _validate_image_path(image_path: Path) -> None: """Validate supported image input paths.""" diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_edit/__init__.py b/embodichain/gen_sim/prompt2scene/workflows/scene_edit/__init__.py new file mode 100644 index 000000000..addab9d30 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_edit/__init__.py @@ -0,0 +1,23 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from embodichain.gen_sim.prompt2scene.workflows.scene_edit.graph import ( + run_scene_edit, +) + +__all__ = ["run_scene_edit"] diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_edit/graph.py b/embodichain/gen_sim/prompt2scene/workflows/scene_edit/graph.py new file mode 100644 index 000000000..316909d2e --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_edit/graph.py @@ -0,0 +1,108 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import shutil +from typing import TYPE_CHECKING + +from embodichain.gen_sim.prompt2scene.utils.io import write_json +from embodichain.gen_sim.prompt2scene.workflows.artifact_writer import ( + write_step_result, +) +from embodichain.gen_sim.prompt2scene.workflows.paths import SCENE_EDIT_STEP +from embodichain.gen_sim.prompt2scene.workflows.scene_edit.nodes import ( + analyze_scene_edit_intent_node, + generate_edit_assets_node, + optimize_edit_layout_node, + update_scene_files_node, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_edit.schema import ( + SceneEditRequest, + SceneEditResult, +) + +__all__ = ["run_scene_edit"] + +if TYPE_CHECKING: + from embodichain.gen_sim.prompt2scene.llms import OpenAICompatibleLLMCfg + + +def run_scene_edit( + request: SceneEditRequest, + *, + llm_cfg: OpenAICompatibleLLMCfg | None = None, +) -> SceneEditResult: + """Run the scene edit workflow.""" + output_root = request.output_root.expanduser().resolve() + scene_state_path = output_root / "gym_export" / "scene_state" / "result.json" + if not scene_state_path.is_file(): + raise FileNotFoundError( + "Scene edit requires an existing exported scene state: " + f"{scene_state_path}" + ) + output_dir = output_root / SCENE_EDIT_STEP + output_dir.mkdir(parents=True, exist_ok=True) + if llm_cfg is None: + raise ValueError("Scene edit requires an LLM config for intent analysis.") + from embodichain.gen_sim.prompt2scene.llms import build_chat_model + + llm = build_chat_model(llm_cfg) + + intent_analysis = analyze_scene_edit_intent_node( + request=request, + output_dir=output_dir, + llm=llm, + ) + generated_assets = generate_edit_assets_node( + intent_analysis=intent_analysis, + output_dir=output_dir, + llm=llm, + ) + layout_result = optimize_edit_layout_node( + intent_analysis=intent_analysis, + generated_assets=generated_assets, + output_dir=output_dir, + ) + file_updates = update_scene_files_node( + intent_analysis=intent_analysis, + generated_assets=generated_assets, + layout_result=layout_result, + output_dir=output_dir, + ) + + result = SceneEditResult( + status="ok" if file_updates.get("status") == "ok" else "partial", + prompt=request.prompt, + scene_state_path=scene_state_path, + reason=( + "Scene edit intent analysis, asset generation, layout optimization, " + "and gym_export file updates completed." + ), + steps={ + "intent_analysis": intent_analysis, + "generated_assets": generated_assets, + "layout_optimization": layout_result, + "file_updates": file_updates, + }, + ) + resolved_intent = intent_analysis.get("resolved_intent") + if isinstance(resolved_intent, dict): + write_json(output_dir / "resolved_intent.json", resolved_intent) + write_step_result(output_root, SCENE_EDIT_STEP, result.to_manifest()) + if request.cleanup_scene_edit_dir and output_dir.is_dir(): + shutil.rmtree(output_dir) + return result diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_edit/nodes.py b/embodichain/gen_sim/prompt2scene/workflows/scene_edit/nodes.py new file mode 100644 index 000000000..79740805a --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_edit/nodes.py @@ -0,0 +1,326 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from embodichain.gen_sim.prompt2scene.prompts.builders import ( + build_scene_edit_intent_messages, + build_text_metric_scale_messages, +) +from embodichain.gen_sim.prompt2scene.prompts.schemas import ( + IMAGE_METRIC_SCALE_JSON_SCHEMA, + SCENE_EDIT_INTENT_JSON_SCHEMA, +) +from embodichain.gen_sim.prompt2scene.agent_tools.tools.text_scene_metric_scale import ( + estimate_text_scene_metric_scale, +) +from embodichain.gen_sim.prompt2scene.utils import ( + log, + log_api_request_start, +) +from embodichain.gen_sim.prompt2scene.workflows.stage_errors import ( + format_attempt_error, + format_result_missing_error, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_edit.schema import ( + SceneEditRequest, +) +from embodichain.gen_sim.prompt2scene.workflows.scene_edit.utils import ( + build_scene_edit_layout, + export_scene_edit_gym_state, + extract_current_grids, + extract_current_relations, + extract_scene_objects, + generate_scene_edit_object_assets, + load_json_object, + match_prompt_scene_objects, + resolve_scene_edit_intent, + scene_state_path, +) + +__all__ = [ + "analyze_scene_edit_intent_node", + "generate_edit_assets_node", + "optimize_edit_layout_node", + "update_scene_files_node", +] + + +def analyze_scene_edit_intent_node( + *, + request: SceneEditRequest, + output_dir: Path, + llm: Any, +) -> dict[str, Any]: + """Analyze existing scene state plus user prompt into structured edit intent.""" + state_path = scene_state_path(request.output_root) + if not state_path.is_file(): + raise FileNotFoundError( + "Scene edit requires an existing exported scene state: " + f"{state_path}" + ) + scene_state = load_json_object(state_path) + scene_objects = extract_scene_objects(scene_state) + current_relations = extract_current_relations( + output_root=request.output_root, + scene_state=scene_state, + ) + current_grids = extract_current_grids( + output_root=request.output_root, + scene_state=scene_state, + ) + messages = build_scene_edit_intent_messages( + prompt=request.prompt, + scene_objects=scene_objects, + current_relations=current_relations, + ) + from embodichain.gen_sim.prompt2scene.llms.llm_output import ( + StructuredModelCallError, + call_structured_json_model_step, + ) + + attempt_count = 0 + max_attempts = 3 + errors: list[str] = [] + raw_model_output: dict[str, Any] | None = None + retry_messages = list(messages) + persist_raw_model_output = False + while attempt_count < max_attempts: + attempt_count += 1 + try: + log_api_request_start( + step="scene_edit", + request="intent_analysis", + attempt=attempt_count, + ) + raw_model_output = call_structured_json_model_step( + llm=llm, + schema=SCENE_EDIT_INTENT_JSON_SCHEMA, + messages=retry_messages, + context="Scene edit intent", + attempt_count=attempt_count, + raw_output_writer=None, + ) + break + except StructuredModelCallError as exc: + error = format_attempt_error("Scene edit intent", attempt_count, exc) + errors.append(error) + log.log_warning(error) + persist_raw_model_output = True + retry_messages = list(messages) + [ + { + "role": "user", + "content": ( + "The previous JSON output failed schema validation. " + f"Fix this exact error and output the full JSON again: {exc}" + ), + } + ] + + if raw_model_output is None: + raise RuntimeError( + format_result_missing_error( + "Scene edit intent", + "SceneEditIntentOutput", + attempt_count=attempt_count, + last_error=errors[-1] if errors else None, + errors=errors, + ) + ) + + resolved_intent = resolve_scene_edit_intent( + intent=raw_model_output, + scene_objects=scene_objects, + current_relations=current_relations, + current_grids=current_grids, + ) + source_snapshots = scene_state.get("source_snapshots") or {} + object_matches = match_prompt_scene_objects( + prompt=request.prompt, + scene_state=scene_state, + ) + analysis = { + "status": "ok", + "node": "analyze_scene_edit_intent", + "prompt": request.prompt, + "scene_state_path": str(state_path), + "source_snapshots": source_snapshots, + "scene_summary": { + "object_count": len(scene_objects), + "objects": scene_objects, + }, + "current_relations": current_relations, + "current_grid_assignments": current_grids, + "prompt_object_matches": object_matches, + "llm_intent": raw_model_output, + "resolved_intent": resolved_intent, + } + if persist_raw_model_output: + analysis["debug"] = {"retry_errors": errors} + return analysis + + +def generate_edit_assets_node( + *, + intent_analysis: dict[str, Any], + output_dir: Path, + llm: Any | None = None, +) -> dict[str, Any]: + """Generate simready assets for add/replace objects in a scene edit.""" + intent = intent_analysis.get("resolved_intent") + if not isinstance(intent, dict): + intent = {} + generated_objects = intent.get("generated_objects") + if not isinstance(generated_objects, list): + generated_objects = [] + if not generated_objects: + return { + "status": "ok", + "node": "generate_edit_assets", + "input_intent_status": intent_analysis.get("status"), + "objects_to_generate": [], + "generated_assets": [], + "reason": "No new objects were requested by the edit intent.", + } + generation_result = generate_scene_edit_object_assets( + generated_objects=generated_objects, + output_root=output_dir.parent, + output_dir=output_dir, + ) + generated_assets = generation_result.get("generated_assets", []) + if isinstance(generated_assets, list) and generated_assets: + metric_prompt_objects = [ + { + "object_id": str(obj.get("id", "")), + "object_name": str(obj.get("name", "")), + "object_description": str(obj.get("description", "")), + } + for obj in generated_assets + ] + prompt_text = str(intent_analysis.get("prompt") or "") + metric_scale_result = estimate_text_scene_metric_scale( + object_results=generated_assets, + user_text=prompt_text, + messages=build_text_metric_scale_messages( + user_text=prompt_text, + objects_json=metric_prompt_objects, + ), + schema=IMAGE_METRIC_SCALE_JSON_SCHEMA, + output_dir=output_dir / "glb_gen" / "metric_scale", + output_root=output_dir.parent, + llm=llm, + step_name="scene_edit", + ) + else: + metric_scale_result = { + "status": "skipped", + "reason": "no_generated_assets", + "objects": [], + } + result = { + "status": generation_result.get("status", "partial"), + "node": "generate_edit_assets", + "input_intent_status": intent_analysis.get("status"), + "objects_to_generate": generated_objects, + "generated_assets": generated_assets, + "object_count": generation_result.get("object_count", 0), + "metric_scale": metric_scale_result, + "reason": ( + "Generated simready assets for scene-edit add/replace objects." + ), + } + return result + + +def optimize_edit_layout_node( + *, + intent_analysis: dict[str, Any], + generated_assets: dict[str, Any], + output_dir: Path, +) -> dict[str, Any]: + """Load the previous 2D layout and optimize an edited scene layout.""" + scene_state_value = intent_analysis.get("scene_state_path", "") + scene_state = load_json_object(Path(str(scene_state_value))) + resolved_intent = intent_analysis.get("resolved_intent") + if not isinstance(resolved_intent, dict): + resolved_intent = {} + generated_asset_items = generated_assets.get("generated_assets") + if not isinstance(generated_asset_items, list): + generated_asset_items = [] + layout = build_scene_edit_layout( + scene_state=scene_state, + resolved_intent=resolved_intent, + generated_assets=generated_asset_items, + output_root=output_dir.parent, + ) + return { + "status": layout.get("status", "ok"), + "node": "optimize_edit_layout", + "existing_scene_state_path": scene_state_value, + "generated_asset_count": len(generated_asset_items), + "deleted_object_ids": layout.get("deleted_object_ids", []), + "support_region": layout.get("support_region", {}), + "layout_updates": layout.get("layout_updates", []), + "optimization": layout.get("optimization", {}), + "reason": ( + "Loaded the previous scene_state 2D footprints, inherited replacement " + "object centers, computed generated-object XY sizes from simready GLBs, " + "and applied relation/grid-based local layout optimization." + ), + } + + +def update_scene_files_node( + *, + intent_analysis: dict[str, Any], + generated_assets: dict[str, Any], + layout_result: dict[str, Any], + output_dir: Path, +) -> dict[str, Any]: + """Update gym_export outputs so future scene edits read the edited scene.""" + scene_state_value = intent_analysis.get("scene_state_path", "") + scene_state = load_json_object(Path(str(scene_state_value))) + generated_asset_items = generated_assets.get("generated_assets") + if not isinstance(generated_asset_items, list): + generated_asset_items = [] + layout_updates = layout_result.get("layout_updates") + if not isinstance(layout_updates, list): + layout_updates = [] + export_result = export_scene_edit_gym_state( + output_root=output_dir.parent, + scene_state=scene_state, + generated_assets=generated_asset_items, + layout_updates=layout_updates, + output_dir=output_dir, + ) + return { + "status": export_result.get("status", "ok"), + "node": "update_scene_files", + "updated_files": export_result.get("updated_files", []), + "reason": ( + "Updated gym_export outputs from the edited scene layout, including " + "gym_config, scene_state/result.json, topdown_2d.png, and any new " + "simready mesh assets." + ), + "inputs": { + "intent_status": intent_analysis.get("status"), + "generated_assets_status": generated_assets.get("status"), + "layout_status": layout_result.get("status"), + }, + } diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_edit/schema.py b/embodichain/gen_sim/prompt2scene/workflows/scene_edit/schema.py new file mode 100644 index 000000000..6af4023eb --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_edit/schema.py @@ -0,0 +1,52 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +__all__ = ["SceneEditRequest", "SceneEditResult"] + + +@dataclass(frozen=True) +class SceneEditRequest: + """Input for editing an existing generated scene.""" + + output_root: Path + prompt: str + cleanup_scene_edit_dir: bool = True + + +@dataclass(frozen=True) +class SceneEditResult: + """Structured result for the scene edit workflow skeleton.""" + + status: str + prompt: str + scene_state_path: Path + reason: str + steps: dict[str, Any] + + def to_manifest(self) -> dict[str, Any]: + return { + "status": self.status, + "prompt": self.prompt, + "scene_state_path": str(self.scene_state_path), + "reason": self.reason, + "steps": self.steps, + } diff --git a/embodichain/gen_sim/prompt2scene/workflows/scene_edit/utils.py b/embodichain/gen_sim/prompt2scene/workflows/scene_edit/utils.py new file mode 100644 index 000000000..1567a44f1 --- /dev/null +++ b/embodichain/gen_sim/prompt2scene/workflows/scene_edit/utils.py @@ -0,0 +1,1554 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +import math +import re +import shutil +from pathlib import Path +from typing import Any + +import numpy as np + +from embodichain.gen_sim.prompt2scene.agent_tools.managers.layout_manager import ( + LayoutManager, +) +from embodichain.gen_sim.prompt2scene.agent_tools.tools.text_asset_generation import ( + generate_text_object_assets, +) +from embodichain.gen_sim.prompt2scene.utils.io import relative_path, write_json +from embodichain.gen_sim.prompt2scene.workflows.gym_export import ( + _glb_scale_to_sim, + _render_scene_state_topdown, +) +from embodichain.gen_sim.prompt2scene.workflows.paths import PipelinePaths +from embodichain.gen_sim.prompt2scene.agent_tools.tools.spatial_relations import ( + transitive_relation_closure, +) +from embodichain.gen_sim.prompt2scene.utils.log import log_info + +__all__ = [ + "build_scene_edit_layout", + "build_xy_footprint", + "clamp_center_to_support_region", + "compute_simready_glb_xy_size", + "extract_current_grids", + "extract_current_relations", + "extract_scene_edit_support_region", + "extract_scene_object_footprints", + "extract_scene_objects", + "generate_scene_edit_object_assets", + "export_scene_edit_gym_state", + "load_json_object", + "match_prompt_scene_objects", + "resolve_scene_edit_intent", + "resolve_scene_state_snapshot_path", + "scene_state_path", + "tokenize_text", + "validate_scene_edit_intent", +] + + +def scene_state_path(output_root: Path) -> Path: + return output_root / "gym_export" / "scene_state" / "result.json" + + +def load_json_object(path: Path) -> dict[str, Any]: + with path.open("r", encoding="utf-8") as f: + data = json.load(f) + if not isinstance(data, dict): + raise ValueError(f"Expected JSON object at {path}") + return data + + +def extract_scene_objects(scene_state: dict[str, Any]) -> list[dict[str, str]]: + """Return the minimal object view used by the edit-intent LLM.""" + objects: list[dict[str, str]] = [] + for obj in scene_state.get("objects", []) or []: + if not isinstance(obj, dict): + continue + object_id = str(obj.get("id", "")).strip() + if not object_id: + continue + objects.append( + { + "id": object_id, + "name": str(obj.get("name", "")).strip(), + "description": str(obj.get("description", "")).strip(), + } + ) + return objects + + +def extract_scene_object_footprints( + scene_state: dict[str, Any], +) -> dict[str, dict[str, Any]]: + """Return normalized object 2D footprints keyed by object id.""" + footprints: dict[str, dict[str, Any]] = {} + for obj in scene_state.get("objects", []) or []: + if not isinstance(obj, dict): + continue + object_id = str(obj.get("id", "")).strip() + footprint = obj.get("footprint_2d") + if not object_id or not isinstance(footprint, dict): + continue + center_xy = footprint.get("center_xy") + aabb_xy = footprint.get("aabb_xy") + size_xy = footprint.get("size_xy") + if not ( + isinstance(center_xy, list) + and len(center_xy) == 2 + and isinstance(aabb_xy, list) + and len(aabb_xy) == 2 + and all(isinstance(item, list) and len(item) == 2 for item in aabb_xy) + and isinstance(size_xy, list) + and len(size_xy) == 2 + ): + continue + footprints[object_id] = { + "unit": str(footprint.get("unit", "m")).strip() or "m", + "center_xy": [float(value) for value in center_xy], + "aabb_xy": [ + [float(value) for value in aabb_xy[0]], + [float(value) for value in aabb_xy[1]], + ], + "size_xy": [float(value) for value in size_xy], + } + return footprints + + +def extract_scene_edit_support_region(scene_state: dict[str, Any]) -> dict[str, Any]: + """Return the table support-region 2D manifest from the previous scene.""" + table = scene_state.get("table") + if not isinstance(table, dict): + return {"unit": "m", "center_xy": [], "aabb_xy": [], "size_xy": [], "corners_xy": []} + support_region = table.get("support_region_2d") + if not isinstance(support_region, dict): + return {"unit": "m", "center_xy": [], "aabb_xy": [], "size_xy": [], "corners_xy": []} + return support_region + + +def resolve_scene_state_snapshot_path( + *, + output_root: Path, + scene_state: dict[str, Any], + snapshot_name: str, +) -> Path | None: + """Resolve a snapshot path recorded in gym_export/scene_state/result.json.""" + source_snapshots = scene_state.get("source_snapshots") + if not isinstance(source_snapshots, dict): + return None + snapshot_value = source_snapshots.get(snapshot_name) + if not isinstance(snapshot_value, str) or not snapshot_value: + return None + snapshot_path = Path(snapshot_value) + if snapshot_path.is_absolute(): + return snapshot_path + return output_root / "gym_export" / snapshot_path + + +def extract_current_relations( + *, + output_root: Path, + scene_state: dict[str, Any], +) -> list[dict[str, str]]: + """Load canonical relations from the unified_scene snapshot if available.""" + snapshot_path = resolve_scene_state_snapshot_path( + output_root=output_root, + scene_state=scene_state, + snapshot_name="unified_scene", + ) + if snapshot_path is None or not snapshot_path.is_file(): + return [] + unified_scene = load_json_object(snapshot_path) + spatial = unified_scene.get("spatial") + if not isinstance(spatial, dict): + return [] + relations = spatial.get("relations") + if not isinstance(relations, list): + return [] + + normalized: list[dict[str, str]] = [] + for relation in relations: + if not isinstance(relation, dict): + continue + subject = str(relation.get("subject", "")).strip() + relation_name = str(relation.get("relation", "")).strip() + object_id = str(relation.get("object", "")).strip() + if not subject or not relation_name or not object_id: + continue + normalized.append( + { + "subject": subject, + "relation": relation_name, + "object": object_id, + "source": str(relation.get("source", "")).strip(), + } + ) + return normalized + + +def extract_current_grids( + *, + output_root: Path, + scene_state: dict[str, Any], +) -> dict[str, str]: + """Load object 9-grid assignments from the unified_scene snapshot.""" + snapshot_path = resolve_scene_state_snapshot_path( + output_root=output_root, + scene_state=scene_state, + snapshot_name="unified_scene", + ) + if snapshot_path is None or not snapshot_path.is_file(): + return {} + unified_scene = load_json_object(snapshot_path) + objects = unified_scene.get("objects") + if not isinstance(objects, list): + return {} + + grids: dict[str, str] = {} + for obj in objects: + if not isinstance(obj, dict): + continue + object_id = str(obj.get("id", "")).strip() + grid = str(obj.get("grid", "") or "").strip() + if object_id and grid: + grids[object_id] = grid + return grids + + +def resolve_scene_edit_intent( + *, + intent: dict[str, Any], + scene_objects: list[dict[str, str]], + current_relations: list[dict[str, str]], + current_grids: dict[str, str], +) -> dict[str, Any]: + """Resolve LLM edit operations into program-computed relations and grids.""" + validate_scene_edit_intent(intent=intent, scene_objects=scene_objects) + + operations = [op for op in intent.get("operations", []) if isinstance(op, dict)] + generated_objects = _normalize_generated_objects( + operations=operations, + generated_objects=[ + obj for obj in intent.get("generated_objects", []) if isinstance(obj, dict) + ], + ) + generated_ids = { + str(obj.get("temp_id", "")).strip() + for obj in generated_objects + if str(obj.get("temp_id", "")).strip() + } + deleted_ids = _string_set(intent.get("deleted_object_ids"), "deleted_object_ids") + + replacement_map: dict[str, str] = {} + replacement_inherits: set[str] = set() + for operation in operations: + if operation.get("type") != "replace": + continue + target_id = str(operation.get("target_object_id", "")).strip() + new_id = str(operation.get("new_object_temp_id", "")).strip() + if not target_id or not new_id: + continue + replacement_map[target_id] = new_id + placement = operation.get("placement") + placement_type = ( + str(placement.get("type", "")).strip() + if isinstance(placement, dict) + else "" + ) + if placement_type in {"", "preserve_target"}: + replacement_inherits.add(target_id) + + direct_relations: list[dict[str, str]] = [] + for relation in current_relations: + subject = str(relation.get("subject", "")).strip() + object_id = str(relation.get("object", "")).strip() + relation_name = str(relation.get("relation", "")).strip() + mapped_subject = _map_relation_endpoint( + object_id=subject, + deleted_ids=deleted_ids, + replacement_map=replacement_map, + replacement_inherits=replacement_inherits, + ) + mapped_object = _map_relation_endpoint( + object_id=object_id, + deleted_ids=deleted_ids, + replacement_map=replacement_map, + replacement_inherits=replacement_inherits, + ) + if mapped_subject is None or mapped_object is None: + continue + if mapped_subject == mapped_object: + continue + direct_relations.append( + { + "subject": mapped_subject, + "relation": relation_name, + "object": mapped_object, + "source": ( + "replacement_inherited" + if mapped_subject != subject or mapped_object != object_id + else "preserved" + ), + } + ) + + updated_grids: dict[str, str] = {} + for object_id, grid in current_grids.items(): + if object_id in deleted_ids: + replacement_id = replacement_map.get(object_id) + if replacement_id and object_id in replacement_inherits: + updated_grids[replacement_id] = grid + continue + updated_grids[object_id] = grid + + for operation in operations: + op_type = str(operation.get("type", "")).strip() + if op_type not in {"add", "replace"}: + continue + new_id = str(operation.get("new_object_temp_id", "")).strip() + if new_id not in generated_ids: + continue + placement = operation.get("placement") + if not isinstance(placement, dict): + continue + placement_type = str(placement.get("type", "")).strip() + if placement_type == "grid": + grid = str(placement.get("grid", "")).strip() + if grid: + updated_grids[new_id] = grid + elif placement_type == "relative_to_object": + reference_id = _map_reference_endpoint( + object_id=str(placement.get("reference_object_id", "")).strip(), + deleted_ids=deleted_ids, + replacement_map=replacement_map, + ) + relation = _placement_relation_to_canonical( + new_object_id=new_id, + relation=str(placement.get("relation", "")).strip(), + reference_object_id=reference_id or "", + ) + if relation is not None: + direct_relations.append({**relation, "source": "new_prompt"}) + + return { + "deleted_object_ids": sorted(deleted_ids), + "generated_objects": generated_objects, + "operations": operations, + "updated_relations": _close_relations_with_sources(direct_relations), + "updated_grid_assignments": dict(sorted(updated_grids.items())), + "unresolved": intent.get("unresolved", []), + "reason": intent.get("reason", ""), + } + + +def tokenize_text(value: str) -> set[str]: + return { + token + for token in re.split(r"[^a-zA-Z0-9]+", value.lower()) + if len(token) >= 2 + } + + +def match_prompt_scene_objects( + *, + prompt: str, + scene_state: dict[str, Any], +) -> list[dict[str, Any]]: + """Return rough object candidates mentioned by the edit prompt.""" + prompt_tokens = tokenize_text(prompt) + matches: list[dict[str, Any]] = [] + for obj in scene_state.get("objects", []) or []: + if not isinstance(obj, dict): + continue + text = " ".join( + str(obj.get(key, "")) + for key in ("id", "name", "description") + ) + object_tokens = tokenize_text(text.replace("_", " ")) + overlap = sorted(prompt_tokens & object_tokens) + if not overlap: + continue + score = len(overlap) / max(len(object_tokens), 1) + matches.append( + { + "id": obj.get("id", ""), + "name": obj.get("name", ""), + "description": obj.get("description", ""), + "matched_tokens": overlap, + "score": score, + "footprint_2d": obj.get("footprint_2d"), + } + ) + return sorted(matches, key=lambda item: float(item["score"]), reverse=True) + + +def generate_scene_edit_object_assets( + *, + generated_objects: list[dict[str, Any]], + output_root: Path, + output_dir: Path, +) -> dict[str, Any]: + """Generate simready assets for scene-edit add/replace objects.""" + image_gen_dir = output_dir / "image_gen" + glb_gen_dir = output_dir / "glb_gen" + debug_dir = output_dir / "debug" + image_gen_dir.mkdir(parents=True, exist_ok=True) + glb_gen_dir.mkdir(parents=True, exist_ok=True) + debug_dir.mkdir(parents=True, exist_ok=True) + + object_specs = [ + _scene_edit_object_spec(generated_object) + for generated_object in generated_objects + ] + log_info( + "scene_edit object asset generation started " + f"count={len(object_specs)} output_dir={output_dir}" + ) + object_results = generate_text_object_assets( + object_specs=object_specs, + image_gen_dir=image_gen_dir, + glb_gen_dir=glb_gen_dir, + debug_dir=debug_dir, + ) + normalized_results = [ + _normalize_generated_asset_result(result, output_root=output_root) + for result in object_results + ] + succeeded = sum( + str(result.get("status", "")).strip() == "ok" + for result in normalized_results + ) + status = "ok" if succeeded == len(normalized_results) else "partial" + if not normalized_results: + status = "ok" + log_info( + "scene_edit object asset generation completed " + f"succeeded={succeeded} failed={len(normalized_results) - succeeded}" + ) + return { + "status": status, + "object_count": len(normalized_results), + "generated_assets": normalized_results, + } + + +def compute_simready_glb_xy_size( + glb_path: Path, + *, + metric_scale: dict[str, Any] | None = None, +) -> list[float]: + """Compute the sim-plane XY footprint size from a Y-up GLB.""" + try: + import trimesh + except ImportError as exc: + raise RuntimeError("Scene edit layout requires trimesh.") from exc + + scene = trimesh.load(glb_path, force="scene") + if isinstance(scene, trimesh.Trimesh): + mesh = scene + else: + dumped = scene.dump(concatenate=True) + mesh = ( + dumped + if isinstance(dumped, trimesh.Trimesh) + else trimesh.util.concatenate( + [item for item in dumped if isinstance(item, trimesh.Trimesh)] + ) + ) + bounds = np.asarray(mesh.bounds, dtype=np.float64) + if bounds.shape != (2, 3): + raise ValueError(f"Invalid GLB bounds shape: {bounds.shape}") + size_x = float(bounds[1, 0] - bounds[0, 0]) + size_y = float(bounds[1, 2] - bounds[0, 2]) + scale_factor = 1.0 + if isinstance(metric_scale, dict): + try: + scale_factor = float(metric_scale.get("scale_factor", 1.0)) + except (TypeError, ValueError): + scale_factor = 1.0 + if not np.isfinite(scale_factor) or scale_factor <= 0.0: + scale_factor = 1.0 + return [ + max(size_x * scale_factor, 1.0e-4), + max(size_y * scale_factor, 1.0e-4), + ] + + +def build_xy_footprint( + *, + center_xy: list[float], + size_xy: list[float], +) -> dict[str, Any]: + """Build a footprint_2d record from center and size.""" + cx, cy = float(center_xy[0]), float(center_xy[1]) + sx, sy = max(float(size_xy[0]), 0.0), max(float(size_xy[1]), 0.0) + half_x = 0.5 * sx + half_y = 0.5 * sy + return { + "unit": "m", + "center_xy": [cx, cy], + "aabb_xy": [ + [cx - half_x, cy - half_y], + [cx + half_x, cy + half_y], + ], + "size_xy": [sx, sy], + } + + +def clamp_center_to_support_region( + *, + center_xy: list[float], + size_xy: list[float], + support_region: dict[str, Any], +) -> list[float]: + """Clamp an object centre so its AABB stays inside the support-region AABB.""" + aabb_xy = support_region.get("aabb_xy") + if not ( + isinstance(aabb_xy, list) + and len(aabb_xy) == 2 + and all(isinstance(item, list) and len(item) == 2 for item in aabb_xy) + ): + return [float(center_xy[0]), float(center_xy[1])] + min_xy = np.asarray(aabb_xy[0], dtype=np.float64) + max_xy = np.asarray(aabb_xy[1], dtype=np.float64) + half = 0.5 * np.asarray(size_xy, dtype=np.float64) + center = np.asarray(center_xy, dtype=np.float64) + lower = min_xy + half + upper = max_xy - half + clamped = center.copy() + for axis in range(2): + if lower[axis] <= upper[axis]: + clamped[axis] = min(max(center[axis], lower[axis]), upper[axis]) + else: + clamped[axis] = float(0.5 * (min_xy[axis] + max_xy[axis])) + return clamped.tolist() + + +def build_scene_edit_layout( + *, + scene_state: dict[str, Any], + resolved_intent: dict[str, Any], + generated_assets: list[dict[str, Any]], + output_root: Path, +) -> dict[str, Any]: + """Build an edited 2D layout on top of the previous scene state.""" + support_region = extract_scene_edit_support_region(scene_state) + old_footprints = extract_scene_object_footprints(scene_state) + old_objects_by_id = { + str(obj.get("id", "")).strip(): obj + for obj in scene_state.get("objects", []) or [] + if isinstance(obj, dict) and str(obj.get("id", "")).strip() + } + deleted_ids = { + str(object_id).strip() + for object_id in resolved_intent.get("deleted_object_ids", []) or [] + if str(object_id).strip() + } + operations = [ + op for op in resolved_intent.get("operations", []) or [] if isinstance(op, dict) + ] + updated_relations = [ + relation + for relation in resolved_intent.get("updated_relations", []) or [] + if isinstance(relation, dict) + ] + updated_grids = { + str(object_id).strip(): str(grid).strip() + for object_id, grid in (resolved_intent.get("updated_grid_assignments") or {}).items() + if str(object_id).strip() and str(grid).strip() + } + generated_asset_by_id = { + str(asset.get("id", "")).strip(): asset + for asset in generated_assets + if isinstance(asset, dict) + and str(asset.get("id", "")).strip() + and str(asset.get("status", "")).strip() == "ok" + } + + replacement_target_by_new_id: dict[str, str] = {} + placement_by_new_id: dict[str, dict[str, Any]] = {} + added_ids: list[str] = [] + replaced_ids: list[str] = [] + for operation in operations: + op_type = str(operation.get("type", "")).strip() + new_id = str(operation.get("new_object_temp_id", "")).strip() + if not new_id: + continue + placement = operation.get("placement") + if isinstance(placement, dict): + placement_by_new_id[new_id] = placement + if op_type == "replace": + target_id = str(operation.get("target_object_id", "")).strip() + if target_id: + replacement_target_by_new_id[new_id] = target_id + replaced_ids.append(new_id) + elif op_type == "add": + added_ids.append(new_id) + + final_items: dict[str, dict[str, Any]] = {} + for object_id, obj in old_objects_by_id.items(): + if object_id in deleted_ids: + continue + footprint = old_footprints.get(object_id) + if footprint is None: + continue + final_items[object_id] = { + "id": object_id, + "name": str(obj.get("name", "")).strip(), + "description": str(obj.get("description", "")).strip(), + "action": "keep", + "center_xy": list(footprint["center_xy"]), + "size_xy": list(footprint["size_xy"]), + "footprint_2d": footprint, + "source": "previous_scene", + } + + generated_ids = sorted(generated_asset_by_id) + if not generated_ids: + return { + "status": "ok", + "support_region": support_region, + "deleted_object_ids": sorted(deleted_ids), + "layout_updates": sorted(final_items.values(), key=lambda item: item["id"]), + "optimization": { + "method": "reuse_previous_scene", + "generated_object_count": 0, + }, + } + + xy_sizes = { + object_id: np.asarray( + compute_simready_glb_xy_size( + _resolve_generated_asset_path( + generated_asset_by_id[object_id], + output_root=output_root, + ), + metric_scale=generated_asset_by_id[object_id].get("metric_scale"), + ), + dtype=np.float64, + ) + for object_id in generated_ids + } + fixed_ids = set(replaced_ids) + + for object_id in replaced_ids: + if object_id not in generated_asset_by_id: + continue + target_id = replacement_target_by_new_id.get(object_id, "") + target_footprint = old_footprints.get(target_id) + if target_footprint is None: + continue + asset = generated_asset_by_id[object_id] + center_xy = clamp_center_to_support_region( + center_xy=list(target_footprint["center_xy"]), + size_xy=xy_sizes[object_id].tolist(), + support_region=support_region, + ) + final_items[object_id] = { + "id": object_id, + "name": str(asset.get("name", "")).strip(), + "description": str(asset.get("description", "")).strip(), + "action": "replace", + "replaces": target_id, + "center_xy": center_xy, + "size_xy": xy_sizes[object_id].tolist(), + "footprint_2d": build_xy_footprint( + center_xy=center_xy, + size_xy=xy_sizes[object_id].tolist(), + ), + "source": "generated_asset", + "simready_geometry_path": asset.get("simready_geometry_path") + or asset.get("mesh_path"), + } + + initialized_added_centers = _initialize_added_object_centers( + added_ids=[object_id for object_id in added_ids if object_id in generated_asset_by_id], + placement_by_new_id=placement_by_new_id, + updated_grids=updated_grids, + updated_relations=updated_relations, + stable_items=final_items, + support_region=support_region, + xy_sizes=xy_sizes, + ) + for object_id in added_ids: + if object_id not in generated_asset_by_id or object_id not in initialized_added_centers: + continue + asset = generated_asset_by_id[object_id] + center_xy = initialized_added_centers[object_id].tolist() + size_xy = xy_sizes[object_id].tolist() + final_items[object_id] = { + "id": object_id, + "name": str(asset.get("name", "")).strip(), + "description": str(asset.get("description", "")).strip(), + "action": "add", + "replaces": "", + "center_xy": center_xy, + "size_xy": size_xy, + "footprint_2d": build_xy_footprint(center_xy=center_xy, size_xy=size_xy), + "source": "generated_asset", + "simready_geometry_path": asset.get("simready_geometry_path") + or asset.get("mesh_path"), + } + + initial_centers_all = { + object_id: np.asarray(item["center_xy"], dtype=np.float64) + for object_id, item in final_items.items() + } + optimized_centers = {object_id: center.copy() for object_id, center in initial_centers_all.items()} + optimization_metadata: dict[str, Any] | None = None + all_object_ids = sorted(final_items) + if all_object_ids: + left_of_edges: list[tuple[str, str]] = [] + front_of_edges: list[tuple[str, str]] = [] + for relation in updated_relations: + subject = str(relation.get("subject", "")).strip() + object_id = str(relation.get("object", "")).strip() + relation_name = str(relation.get("relation", "")).strip() + if subject not in final_items or object_id not in final_items: + continue + if relation_name == "left_of": + left_of_edges.append((subject, object_id)) + elif relation_name == "front_of": + front_of_edges.append((subject, object_id)) + grid_spring_targets = { + object_id: initial_centers_all[object_id].copy() + for object_id in all_object_ids + if object_id in updated_grids and object_id in initial_centers_all + } + optimized_layout = LayoutManager.optimize_text_layout_slp( + object_ids=all_object_ids, + xy_sizes={object_id: np.asarray(final_items[object_id]["size_xy"], dtype=np.float64) for object_id in all_object_ids}, + initial_centers={object_id: initial_centers_all[object_id] for object_id in all_object_ids}, + left_of_edges=left_of_edges, + front_of_edges=front_of_edges, + grid_spring_targets=grid_spring_targets, + padding_ratio=0.08, + ) + all_optimized = { + object_id: np.asarray(center, dtype=np.float64) + for object_id, center in optimized_layout.get("centers", {}).items() + } + anchor_ids = [ + object_id + for object_id in all_object_ids + if object_id not in added_ids and object_id in all_optimized + ] + if not anchor_ids: + anchor_ids = [ + object_id + for object_id in all_object_ids + if object_id in all_optimized + ] + translation = np.zeros(2, dtype=np.float64) + if anchor_ids: + translation = np.mean( + np.vstack( + [ + initial_centers_all[object_id] - all_optimized[object_id] + for object_id in anchor_ids + ] + ), + axis=0, + ) + for object_id, center in all_optimized.items(): + optimized_centers[object_id] = np.asarray( + clamp_center_to_support_region( + center_xy=(center + translation).tolist(), + size_xy=final_items[object_id]["size_xy"], + support_region=support_region, + ), + dtype=np.float64, + ) + optimization_metadata = optimized_layout.get("metadata") + + for object_id, item in final_items.items(): + center_xy = optimized_centers[object_id].tolist() + size_xy = item["size_xy"] + item["center_xy"] = center_xy + item["footprint_2d"] = build_xy_footprint(center_xy=center_xy, size_xy=size_xy) + + return { + "status": "ok", + "support_region": support_region, + "deleted_object_ids": sorted(deleted_ids), + "layout_updates": sorted(final_items.values(), key=lambda item: item["id"]), + "optimization": { + "method": "delete_then_replace_then_add_initialize_then_optimize", + "generated_object_count": len(generated_ids), + "fixed_replacement_count": len(fixed_ids), + "replaced_object_count": len(replaced_ids), + "added_object_count": len(initialized_added_centers), + "initialized_added_object_count": len(initialized_added_centers), + "optimized_object_count": len(all_object_ids), + "added_layout_optimization": optimization_metadata, + }, + } + + +def export_scene_edit_gym_state( + *, + output_root: Path, + scene_state: dict[str, Any], + generated_assets: list[dict[str, Any]], + layout_updates: list[dict[str, Any]], + output_dir: Path, +) -> dict[str, Any]: + """Update gym_export files from scene-edit layout results.""" + paths = PipelinePaths(output_root) + gym_config_path = paths.gym_config + if not gym_config_path.is_file(): + raise FileNotFoundError(f"gym_config.json not found: {gym_config_path}") + gym_config = load_json_object(gym_config_path) + rigid_objects = gym_config.get("rigid_object") + if not isinstance(rigid_objects, list): + raise ValueError("gym_config rigid_object must be a list.") + + scene_objects = scene_state.get("objects") + if not isinstance(scene_objects, list): + raise ValueError("scene_state objects must be a list.") + + rigid_by_id = { + str(item.get("uid", "")).strip(): item + for item in rigid_objects + if isinstance(item, dict) and str(item.get("uid", "")).strip() + } + scene_by_id = { + str(item.get("id", "")).strip(): item + for item in scene_objects + if isinstance(item, dict) and str(item.get("id", "")).strip() + } + generated_asset_by_id = { + str(item.get("id", "")).strip(): item + for item in generated_assets + if isinstance(item, dict) and str(item.get("id", "")).strip() + } + layout_by_id = { + str(item.get("id", "")).strip(): item + for item in layout_updates + if isinstance(item, dict) and str(item.get("id", "")).strip() + } + + scene_state_dir = output_root / "gym_export" / "scene_state" + mesh_assets_dir = output_root / "gym_export" / "mesh_assets" + scene_state_dir.mkdir(parents=True, exist_ok=True) + mesh_assets_dir.mkdir(parents=True, exist_ok=True) + + existing_table_height = _infer_scene_edit_table_height( + rigid_by_id=rigid_by_id, + layout_by_id=layout_by_id, + ) + + updated_rigid_objects: list[dict[str, Any]] = [] + updated_scene_objects: list[dict[str, Any]] = [] + updated_files: list[str] = [] + + for object_id, layout_item in layout_by_id.items(): + action = str(layout_item.get("action", "")).strip() + center_xy = layout_item.get("center_xy") + size_xy = layout_item.get("size_xy") + if not ( + isinstance(center_xy, list) + and len(center_xy) == 2 + and isinstance(size_xy, list) + and len(size_xy) == 2 + ): + continue + old_rigid = rigid_by_id.get(object_id) + old_scene_obj = scene_by_id.get(object_id) + if action == "keep" and old_rigid is None: + continue + + if action == "keep": + updated_rigid = _update_existing_rigid_object( + object_id=object_id, + rigid_object=old_rigid, + old_scene_object=old_scene_obj, + layout_item=layout_item, + ) + else: + generated_asset = generated_asset_by_id.get(object_id) + if generated_asset is None: + raise ValueError(f"Missing generated asset for edited object: {object_id}") + updated_rigid = _build_generated_rigid_object( + object_id=object_id, + layout_item=layout_item, + generated_asset=generated_asset, + output_root=output_root, + mesh_assets_dir=mesh_assets_dir, + table_height=existing_table_height, + ) + shape = updated_rigid.get("shape") + if isinstance(shape, dict): + updated_files.append(str(shape.get("fpath", ""))) + + updated_rigid_objects.append(updated_rigid) + updated_scene_objects.append( + _build_scene_state_object( + object_id=object_id, + layout_item=layout_item, + rigid_object=updated_rigid, + output_root=output_root, + ) + ) + + gym_config["rigid_object"] = updated_rigid_objects + write_json(gym_config_path, gym_config) + updated_files.append(relative_path(gym_config_path, output_root)) + + topdown_path = scene_state_dir / "topdown_2d.png" + _render_scene_state_topdown( + support_region=extract_scene_edit_support_region(scene_state), + objects=updated_scene_objects, + output_path=topdown_path, + ) + updated_files.append(relative_path(topdown_path, output_root)) + + state_payload = dict(scene_state) + state_payload["gym_config_path"] = str(gym_config_path.relative_to(output_root / "gym_export")) + state_payload["topdown_2d_plot_path"] = str(topdown_path.relative_to(output_root / "gym_export")) + state_payload["objects"] = updated_scene_objects + source_snapshots = dict(scene_state.get("source_snapshots") or {}) + layout_snapshot_path = scene_state_dir / "scene_edit_layout.json" + write_json( + layout_snapshot_path, + {"layout_updates": layout_updates}, + ) + source_snapshots["scene_edit_layout"] = str( + layout_snapshot_path.relative_to(output_root / "gym_export") + ) + state_payload["source_snapshots"] = source_snapshots + scene_state_result_path = scene_state_dir / "result.json" + write_json(scene_state_result_path, state_payload) + updated_files.append(relative_path(scene_state_result_path, output_root)) + updated_files.append(relative_path(layout_snapshot_path, output_root)) + + return { + "status": "ok", + "updated_files": sorted(set(updated_files)), + "object_count": len(updated_scene_objects), + "gym_config_path": str(gym_config_path), + "scene_state_path": str(scene_state_result_path), + } + + +def validate_scene_edit_intent( + *, + intent: dict[str, Any], + scene_objects: list[dict[str, str]], +) -> None: + """Validate that an edit intent only references legal object ids.""" + existing_ids = {obj["id"] for obj in scene_objects if obj.get("id")} + deleted_ids = _string_set(intent.get("deleted_object_ids"), "deleted_object_ids") + unknown_deleted = sorted(deleted_ids - existing_ids) + if unknown_deleted: + raise ValueError( + "Scene edit intent deleted unknown object ids: " + f"{unknown_deleted}" + ) + + generated_objects = intent.get("generated_objects") + if not isinstance(generated_objects, list): + raise ValueError("Scene edit intent generated_objects must be a list.") + generated_ids: set[str] = set() + for generated in generated_objects: + if not isinstance(generated, dict): + raise ValueError("Scene edit intent generated_objects entries must be objects.") + temp_id = str(generated.get("temp_id", "")).strip() + if not temp_id: + raise ValueError("Scene edit intent generated object has empty temp_id.") + if temp_id in existing_ids: + raise ValueError( + f"Scene edit generated temp_id collides with existing id: {temp_id}" + ) + if temp_id in generated_ids: + raise ValueError( + f"Scene edit generated temp_id is duplicated: {temp_id}" + ) + generated_ids.add(temp_id) + + operations = intent.get("operations") + if not isinstance(operations, list): + raise ValueError("Scene edit intent operations must be a list.") + for operation in operations: + if not isinstance(operation, dict): + raise ValueError("Scene edit intent operation entries must be objects.") + op_type = str(operation.get("type", "")).strip() + target_id = str(operation.get("target_object_id", "")).strip() + new_temp_id = str(operation.get("new_object_temp_id", "")).strip() + if op_type in {"delete", "replace"} and target_id not in existing_ids: + raise ValueError( + f"Scene edit {op_type} operation targets unknown object id: " + f"{target_id}" + ) + if op_type == "delete" and target_id not in deleted_ids: + raise ValueError( + f"Scene edit delete target is missing from deleted_object_ids: " + f"{target_id}" + ) + if op_type == "replace": + if target_id not in deleted_ids: + raise ValueError( + "Scene edit replace target is missing from deleted_object_ids: " + f"{target_id}" + ) + if new_temp_id not in generated_ids: + raise ValueError( + "Scene edit replace operation references unknown generated " + f"temp_id: {new_temp_id}" + ) + if op_type == "add" and new_temp_id not in generated_ids: + raise ValueError( + f"Scene edit add operation references unknown generated temp_id: {new_temp_id}" + ) + placement = operation.get("placement") + if isinstance(placement, dict): + reference_id = str(placement.get("reference_object_id", "")).strip() + if reference_id and reference_id not in existing_ids: + raise ValueError( + "Scene edit placement references unknown object id: " + f"{reference_id}" + ) + + +def _scene_edit_object_spec(generated_object: dict[str, Any]) -> dict[str, Any]: + temp_id = str(generated_object.get("temp_id", "")).strip() + name = str(generated_object.get("name", "")).strip() + class_candidates = [name] if name else [] + return { + "id": temp_id, + "name": name, + "description": str(generated_object.get("description", "")).strip(), + "class_candidate": class_candidates, + } + + +def _normalize_generated_asset_result( + result: dict[str, Any], + *, + output_root: Path, +) -> dict[str, Any]: + normalized = dict(result) + for key in ( + "image_path", + "raw_geometry_path", + "mesh_path", + "simready_geometry_path", + ): + value = normalized.get(key) + if value: + normalized[key] = relative_path(value, output_root) + return normalized + + +def _map_relation_endpoint( + *, + object_id: str, + deleted_ids: set[str], + replacement_map: dict[str, str], + replacement_inherits: set[str], +) -> str | None: + if object_id in deleted_ids: + replacement_id = replacement_map.get(object_id) + if replacement_id and object_id in replacement_inherits: + return replacement_id + return None + return object_id + + +def _normalize_generated_objects( + *, + operations: list[dict[str, Any]], + generated_objects: list[dict[str, Any]], +) -> list[dict[str, Any]]: + operation_type_by_temp_id: dict[str, str] = {} + for operation in operations: + new_temp_id = str(operation.get("new_object_temp_id", "")).strip() + op_type = str(operation.get("type", "")).strip() + if new_temp_id and op_type in {"add", "replace"}: + operation_type_by_temp_id[new_temp_id] = op_type + + normalized: list[dict[str, Any]] = [] + for generated in generated_objects: + temp_id = str(generated.get("temp_id", "")).strip() + if not temp_id: + continue + source_operation = str(generated.get("source_operation", "")).strip() + normalized.append( + { + **generated, + "source_operation": ( + source_operation or operation_type_by_temp_id.get(temp_id, "add") + ), + } + ) + return normalized + + +def _map_reference_endpoint( + *, + object_id: str, + deleted_ids: set[str], + replacement_map: dict[str, str], +) -> str | None: + if object_id in replacement_map: + return replacement_map[object_id] + if object_id in deleted_ids: + return None + return object_id + + +def _placement_relation_to_canonical( + *, + new_object_id: str, + relation: str, + reference_object_id: str, +) -> dict[str, str] | None: + if not new_object_id or not reference_object_id: + return None + if relation == "left_of": + return { + "subject": new_object_id, + "relation": "left_of", + "object": reference_object_id, + } + if relation == "right_of": + return { + "subject": reference_object_id, + "relation": "left_of", + "object": new_object_id, + } + if relation == "front_of": + return { + "subject": new_object_id, + "relation": "front_of", + "object": reference_object_id, + } + if relation == "back_of": + return { + "subject": reference_object_id, + "relation": "front_of", + "object": new_object_id, + } + return None + + +def _close_relations_with_sources( + direct_relations: list[dict[str, str]], +) -> list[dict[str, str]]: + if not direct_relations: + return [] + source_by_edge = { + ( + str(relation.get("subject", "")).strip(), + str(relation.get("relation", "")).strip(), + str(relation.get("object", "")).strip(), + ): str(relation.get("source", "")).strip() + for relation in direct_relations + } + closed = transitive_relation_closure(direct_relations) + result: list[dict[str, str]] = [] + for relation in closed: + key = ( + relation["subject"], + relation["relation"], + relation["object"], + ) + source = source_by_edge.get(key) + result.append( + { + "subject": relation["subject"], + "relation": relation["relation"], + "object": relation["object"], + "source": source or "transitive_closure", + } + ) + return result + + +def _string_set(value: Any, context: str) -> set[str]: + if not isinstance(value, list): + raise ValueError(f"Scene edit intent {context} must be a list.") + result: set[str] = set() + for item in value: + text = str(item).strip() + if not text: + raise ValueError(f"Scene edit intent {context} contains an empty id.") + result.add(text) + return result + + +def _resolve_generated_asset_path(asset: dict[str, Any], *, output_root: Path) -> Path: + value = asset.get("simready_geometry_path") or asset.get("mesh_path") + path = Path(str(value or "")).expanduser() + if path.is_absolute(): + return path.resolve() + return (output_root / path).resolve() + + +def _infer_scene_edit_table_height( + *, + rigid_by_id: dict[str, dict[str, Any]], + layout_by_id: dict[str, dict[str, Any]], +) -> float: + heights: list[float] = [] + for object_id, rigid_object in rigid_by_id.items(): + layout_item = layout_by_id.get(object_id) + if layout_item is None: + continue + init_pos = rigid_object.get("init_pos") + if not isinstance(init_pos, list) or len(init_pos) != 3: + continue + try: + heights.append(float(init_pos[2])) + except (TypeError, ValueError): + continue + if heights: + return float(np.median(np.asarray(heights, dtype=np.float64))) + return 0.0 + + +def _update_existing_rigid_object( + *, + object_id: str, + rigid_object: dict[str, Any] | None, + old_scene_object: dict[str, Any] | None, + layout_item: dict[str, Any], +) -> dict[str, Any]: + if rigid_object is None: + raise ValueError(f"Missing rigid_object for existing scene object: {object_id}") + updated = json.loads(json.dumps(rigid_object)) + old_center = _scene_edit_center_xy(old_scene_object) + new_center = np.asarray(layout_item.get("center_xy", []), dtype=np.float64) + init_pos = list(updated.get("init_pos") or [0.0, 0.0, 0.0]) + if old_center is not None and new_center.shape == (2,): + delta = new_center - old_center + init_pos[0] = float(init_pos[0]) + float(delta[0]) + init_pos[1] = float(init_pos[1]) + float(delta[1]) + updated["init_pos"] = [float(value) for value in init_pos] + updated["description"] = str(layout_item.get("description", "")).strip() or str( + updated.get("description", "") + ).strip() + return updated + + +def _build_generated_rigid_object( + *, + object_id: str, + layout_item: dict[str, Any], + generated_asset: dict[str, Any], + output_root: Path, + mesh_assets_dir: Path, + table_height: float, +) -> dict[str, Any]: + simready_path = _resolve_generated_asset_path(generated_asset, output_root=output_root) + if not simready_path.is_file(): + raise FileNotFoundError(f"Generated simready GLB not found: {simready_path}") + safe_name = object_id.replace("interact_", "").strip("_") or "object" + object_dir = mesh_assets_dir / safe_name / object_id + object_dir.mkdir(parents=True, exist_ok=True) + object_dst = object_dir / f"{object_id}.glb" + shutil.copy2(simready_path, object_dst) + + metric_scale = generated_asset.get("metric_scale") + scale_factor = 1.0 + if isinstance(metric_scale, dict): + try: + scale_factor = float(metric_scale.get("scale_factor", 1.0)) + except (TypeError, ValueError): + scale_factor = 1.0 + if not np.isfinite(scale_factor) or scale_factor <= 0.0: + scale_factor = 1.0 + body_scale = _glb_scale_to_sim([scale_factor, scale_factor, scale_factor]) + init_rot = [0.0, 0.0, 0.0] + target_center = np.asarray(layout_item.get("center_xy", []), dtype=np.float64) + if target_center.shape != (2,): + raise ValueError(f"Missing center_xy for generated object: {object_id}") + init_pos = [ + float(target_center[0]), + float(target_center[1]), + float(table_height), + ] + return { + "uid": object_id, + "description": str(layout_item.get("description", "")).strip(), + "shape": { + "shape_type": "Mesh", + "fpath": str(object_dst.relative_to(output_root / "gym_export")), + "compute_uv": False, + }, + "attrs": { + "mass": 0.01, + "contact_offset": 0.003, + "rest_offset": 0.001, + "restitution": 0.01, + "max_depenetration_velocity": 10.0, + "min_position_iters": 32, + "min_velocity_iters": 8, + }, + "body_type": "dynamic", + "init_pos": init_pos, + "init_rot": init_rot, + "body_scale": body_scale, + "max_convex_hull_num": 32, + } + + +def _build_scene_state_object( + *, + object_id: str, + layout_item: dict[str, Any], + rigid_object: dict[str, Any], + output_root: Path, +) -> dict[str, Any]: + init_rot = [float(value) for value in rigid_object.get("init_rot") or [0.0, 0.0, 0.0]] + body_scale = [float(value) for value in rigid_object.get("body_scale") or [1.0, 1.0, 1.0]] + init_pos = [float(value) for value in rigid_object.get("init_pos") or [0.0, 0.0, 0.0]] + footprint_2d = layout_item.get("footprint_2d") or build_xy_footprint( + center_xy=list(layout_item.get("center_xy", [0.0, 0.0])), + size_xy=list(layout_item.get("size_xy", [0.0, 0.0])), + ) + return { + "id": object_id, + "name": str(layout_item.get("name", "")).strip() or object_id, + "role": "interact", + "description": str(layout_item.get("description", "")).strip(), + "init_pos": init_pos, + "init_rot": init_rot, + "body_scale": body_scale, + "footprint_2d": footprint_2d, + } + + +def _scene_edit_center_xy(scene_object: dict[str, Any] | None) -> np.ndarray | None: + if not isinstance(scene_object, dict): + return None + footprint = scene_object.get("footprint_2d") + if not isinstance(footprint, dict): + return None + center_xy = footprint.get("center_xy") + if not isinstance(center_xy, list) or len(center_xy) != 2: + return None + return np.asarray(center_xy, dtype=np.float64) + + +def _compute_anchor_targets( + *, + generated_ids: list[str], + replacement_target_by_new_id: dict[str, str], + placement_by_new_id: dict[str, dict[str, Any]], + updated_grids: dict[str, str], + old_footprints: dict[str, dict[str, Any]], + support_region: dict[str, Any], + xy_sizes: dict[str, np.ndarray], +) -> dict[str, np.ndarray]: + targets: dict[str, np.ndarray] = {} + unresolved = set(generated_ids) + for _ in range(max(len(generated_ids), 1) * 2): + progressed = False + for object_id in list(unresolved): + replacement_target = replacement_target_by_new_id.get(object_id) + if replacement_target: + target_footprint = old_footprints.get(replacement_target) + if target_footprint is None: + continue + targets[object_id] = np.asarray( + target_footprint["center_xy"], + dtype=np.float64, + ) + unresolved.remove(object_id) + progressed = True + continue + + placement = placement_by_new_id.get(object_id, {}) + placement_type = str(placement.get("type", "")).strip() + if placement_type == "relative_to_object": + reference_id = str(placement.get("reference_object_id", "")).strip() + relation = str(placement.get("relation", "")).strip() + reference_center = targets.get(reference_id) + reference_size = xy_sizes.get(reference_id) + if reference_center is None: + reference = old_footprints.get(reference_id) + if reference is not None: + reference_center = np.asarray( + reference["center_xy"], + dtype=np.float64, + ) + reference_size = np.asarray( + reference["size_xy"], + dtype=np.float64, + ) + if reference_center is not None and reference_size is not None: + targets[object_id] = _offset_center_by_relation( + reference_center=reference_center, + reference_size=reference_size, + object_size=xy_sizes[object_id], + relation=relation, + ) + unresolved.remove(object_id) + progressed = True + continue + + grid_name = updated_grids.get(object_id) + if grid_name: + targets[object_id] = _support_region_grid_center( + support_region=support_region, + grid_name=grid_name, + ) + unresolved.remove(object_id) + progressed = True + continue + if not progressed: + break + return targets + + +def _initialize_added_object_centers( + *, + added_ids: list[str], + placement_by_new_id: dict[str, dict[str, Any]], + updated_grids: dict[str, str], + updated_relations: list[dict[str, Any]], + stable_items: dict[str, dict[str, Any]], + support_region: dict[str, Any], + xy_sizes: dict[str, np.ndarray], +) -> dict[str, np.ndarray]: + if not added_ids: + return {} + stable_footprints = { + object_id: { + "center_xy": item["center_xy"], + "size_xy": item["size_xy"], + } + for object_id, item in stable_items.items() + } + anchor_targets = _compute_anchor_targets( + generated_ids=added_ids, + replacement_target_by_new_id={}, + placement_by_new_id=placement_by_new_id, + updated_grids=updated_grids, + old_footprints=stable_footprints, + support_region=support_region, + xy_sizes=xy_sizes, + ) + active_relations = [ + relation + for relation in updated_relations + if str(relation.get("subject", "")).strip() in added_ids + and str(relation.get("object", "")).strip() in added_ids + ] + table_constraints = [ + {"asset": object_id, "grid": updated_grids[object_id]} + for object_id in added_ids + if object_id in updated_grids + ] + layout_seed = LayoutManager.layout_text_objects_grid( + object_ids=added_ids, + xy_sizes={object_id: xy_sizes[object_id] for object_id in added_ids}, + spatial_relations=active_relations, + table_constraints=table_constraints, + ) + relative_centers = { + object_id: np.asarray(layout_seed["centers"][object_id], dtype=np.float64) + for object_id in added_ids + } + translation = np.zeros(2, dtype=np.float64) + if anchor_targets: + translation = np.mean( + np.vstack( + [ + anchor_targets[object_id] - relative_centers[object_id] + for object_id in added_ids + if object_id in anchor_targets + ] + ), + axis=0, + ) + return { + object_id: np.asarray( + clamp_center_to_support_region( + center_xy=(relative_centers[object_id] + translation).tolist(), + size_xy=xy_sizes[object_id].tolist(), + support_region=support_region, + ), + dtype=np.float64, + ) + for object_id in added_ids + } + + +def _support_region_grid_center( + *, + support_region: dict[str, Any], + grid_name: str, +) -> np.ndarray: + aabb_xy = support_region.get("aabb_xy") + if not ( + isinstance(aabb_xy, list) + and len(aabb_xy) == 2 + and all(isinstance(item, list) and len(item) == 2 for item in aabb_xy) + ): + return np.zeros(2, dtype=np.float64) + min_xy = np.asarray(aabb_xy[0], dtype=np.float64) + max_xy = np.asarray(aabb_xy[1], dtype=np.float64) + size = max_xy - min_xy + cell = size / 3.0 + grid_to_rc = { + "left_front": (0, 0), + "center_front": (1, 0), + "right_front": (2, 0), + "left_center": (0, 1), + "center": (1, 1), + "right_center": (2, 1), + "left_back": (0, 2), + "center_back": (1, 2), + "right_back": (2, 2), + "front": (1, 0), + "back": (1, 2), + "left": (0, 1), + "right": (2, 1), + } + col, row = grid_to_rc.get(grid_name, (1, 1)) + center_x = min_xy[0] + (col + 0.5) * cell[0] + center_y = max_xy[1] - (row + 0.5) * cell[1] + return np.asarray([center_x, center_y], dtype=np.float64) + + +def _offset_center_by_relation( + *, + reference_center: np.ndarray, + reference_size: np.ndarray, + object_size: np.ndarray, + relation: str, + padding: float = 0.02, +) -> np.ndarray: + gap_x = 0.5 * (reference_size[0] + object_size[0]) + padding + gap_y = 0.5 * (reference_size[1] + object_size[1]) + padding + offset = np.zeros(2, dtype=np.float64) + if relation == "left_of": + offset[0] = -gap_x + elif relation == "right_of": + offset[0] = gap_x + elif relation == "front_of": + offset[1] = -gap_y + elif relation in {"back_of", "behind"}: + offset[1] = gap_y + else: + offset = np.asarray([gap_x, 0.0], dtype=np.float64) + return reference_center + offset