diff --git a/.gitignore b/.gitignore index 7405b2797..f84991adf 100644 --- a/.gitignore +++ b/.gitignore @@ -190,6 +190,8 @@ embodichain/agents/policy/runs/* *.pth outputs test_configs/* +embodichain_data/ +scripts/tutorials/atomic_action/*.glb wandb/ *.mp4 @@ -200,4 +202,4 @@ wandb/ embodichain/VERSION # benchmark results -scripts/benchmark/rl/reports/* \ No newline at end of file +scripts/benchmark/rl/reports/* diff --git a/docs/source/_static/atomic_actions/coordinated_pickment.gif b/docs/source/_static/atomic_actions/coordinated_pickment.gif new file mode 100644 index 000000000..cbacf8494 Binary files /dev/null and b/docs/source/_static/atomic_actions/coordinated_pickment.gif differ diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.atomic_actions.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.atomic_actions.rst index 90077698f..783f76170 100644 --- a/docs/source/api_reference/embodichain/embodichain.lab.sim.atomic_actions.rst +++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.atomic_actions.rst @@ -16,8 +16,10 @@ embodichain.lab.sim.atomic_actions NamedJointPositionTarget GraspTarget HeldObjectPoseTarget + CoordinatedPickmentTarget Target HeldObjectState + CoordinatedHeldObjectState WorldState ActionResult ActionCfg @@ -35,6 +37,8 @@ embodichain.lab.sim.atomic_actions Place PressCfg Press + CoordinatedPickmentCfg + CoordinatedPickment AtomicActionEngine .. currentmodule:: embodichain.lab.sim.atomic_actions @@ -78,12 +82,20 @@ Core :members: :show-inheritance: +.. autoclass:: CoordinatedPickmentTarget + :members: + :show-inheritance: + .. autodata:: Target .. autoclass:: HeldObjectState :members: :show-inheritance: +.. autoclass:: CoordinatedHeldObjectState + :members: + :show-inheritance: + .. autoclass:: WorldState :members: :show-inheritance: @@ -164,6 +176,15 @@ Actions :members: :show-inheritance: +.. 
autoclass:: CoordinatedPickmentCfg + :members: + :exclude-members: __init__, copy, replace, to_dict + :show-inheritance: + +.. autoclass:: CoordinatedPickment + :members: + :show-inheritance: + Engine & Registry ----------------- diff --git a/docs/source/overview/sim/atomic_actions/builtin_actions.md b/docs/source/overview/sim/atomic_actions/builtin_actions.md index 4ba7589ea..52859c988 100644 --- a/docs/source/overview/sim/atomic_actions/builtin_actions.md +++ b/docs/source/overview/sim/atomic_actions/builtin_actions.md @@ -15,6 +15,7 @@ The following actions are available out of the box: | `MoveHeldObject` | Single | `HeldObjectPoseTarget` — held-object pose | Move held object while keeping gripper closed | MoveHeldObject | | `Place` | Single | `EndEffectorPoseTarget` — EEF release pose | Lower → open gripper → retract | Place | | `Press` | Single | `EndEffectorPoseTarget` — EEF press pose | Close gripper → press down → return | Press | +| `CoordinatedPickment` | Dual | `CoordinatedPickmentTarget` — shared-object pose | Approach both ends → close both grippers → lift → move object | CoordinatedPickment | --- @@ -145,3 +146,28 @@ threaded into it. of shape `(4, 4)` or `(n_envs, 4, 4)`. ![Press demo](../../../_static/atomic_actions/press.gif) + +--- + +## `CoordinatedPickment` + +Dual-arm grasp motion for one shared object. Both arms move to object-relative +grasp poses, close both grippers, lift the object, and move it to an object pose +while keeping both grippers closed. On success, the returned `WorldState` carries +`coordinated_held_object` (`CoordinatedHeldObjectState`) and leaves +`held_object` as `None`. 
+ +| Config field | Default | Description | +|---|---|---| +| `control_part` | `"dual_arm"` | Combined arm control part | +| `left_arm_control_part` / `right_arm_control_part` | `"left_arm"` / `"right_arm"` | Arm control parts for each grasp | +| `left_hand_control_part` / `right_hand_control_part` | `"left_hand"` / `"right_hand"` | Hand control parts for each gripper | +| `pre_grasp_distance` | `0.10` | Distance to back away from each grasp TCP | +| `lift_height` | `0.08` | World-Z lift distance before moving to the target pose | +| `object_motion_keyframes` | `6` | Sparse object-pose IK keyframes for synchronized motion | +| `sample_interval` | `120` | Total waypoints across all phases | + +**Target:** `CoordinatedPickmentTarget(...)` with a target object pose, object +semantics, and left/right object-to-EEF transforms. + +![CoordinatedPickment demo](../../../_static/atomic_actions/coordinated_pickment.gif) diff --git a/embodichain/data/assets/demo_assets.py b/embodichain/data/assets/demo_assets.py index 9b00a654a..255002ab6 100644 --- a/embodichain/data/assets/demo_assets.py +++ b/embodichain/data/assets/demo_assets.py @@ -14,6 +14,8 @@ # limitations under the License. 
class CoordinatedPlacementAndPickment(EmbodiChainDataset):
    """Mesh bundle used by the coordinated placement / pickment tutorials.

    Args:
        data_root: Directory the archive is stored under. When ``None`` the
            package-wide ``EMBODICHAIN_DEFAULT_DATA_ROOT`` is used instead.
    """

    def __init__(self, data_root: str = None):
        # Remote location of the zipped tutorial meshes.
        archive_url = os.path.join(
            EMBODICHAIN_DOWNLOAD_PREFIX,
            demo_assets,
            "coordinated_placement_and_pickment.zip",
        )
        # Second argument is the archive's MD5 checksum expected by Open3D.
        descriptor = o3d.data.DataDescriptor(
            archive_url,
            "297c10b386a4d7a8ccb68926d69425e9",
        )
        if data_root is None:
            storage_root = EMBODICHAIN_DEFAULT_DATA_ROOT
        else:
            storage_root = data_root
        # The dataset prefix is the concrete class name, so subclasses get
        # their own prefix.
        super().__init__(type(self).__name__, descriptor, storage_root)
@configclass
class CoordinatedPickmentCfg(ActionCfg):
    """Configuration for the dual-arm :class:`CoordinatedPickment` action.

    The four ``*_hand_*_qpos`` fields default to ``None`` but are mandatory:
    the action validates them at construction time.
    """

    name: str = "coordinated_pickment"
    """Identifier of the action, used for lookup and log messages."""

    control_part: str = "dual_arm"
    """Control part that contains the joints of both arms."""

    left_arm_control_part: str = "left_arm"
    """Control part of the arm that grasps the object with the left hand."""

    right_arm_control_part: str = "right_arm"
    """Control part of the arm that grasps the object with the right hand."""

    left_hand_control_part: str = "left_hand"
    """Control part of the hand mounted on the left arm."""

    right_hand_control_part: str = "right_hand"
    """Control part of the hand mounted on the right arm."""

    left_hand_open_qpos: torch.Tensor | None = None
    """Joint positions of the left hand when open."""

    left_hand_close_qpos: torch.Tensor | None = None
    """Joint positions of the left hand when closed."""

    right_hand_open_qpos: torch.Tensor | None = None
    """Joint positions of the right hand when open."""

    right_hand_close_qpos: torch.Tensor | None = None
    """Joint positions of the right hand when closed."""

    object_motion_keyframes: int = 6
    """How many object-pose keyframes are solved with IK; the waypoints in
    between are filled in by joint-space interpolation."""

    pre_grasp_distance: float = 0.10
    """Retreat distance from each grasp pose along the negative TCP z axis."""

    lift_height: float = 0.08
    """Z lift distance applied before moving to the object target pose."""

    sample_interval: int = 120
    """Total number of waypoints of the whole coordinated pickment trajectory."""

    hand_interp_steps: int = 10
    """Waypoints spent on closing both hands simultaneously."""

    hold_steps: int = 4
    """Waypoints spent holding the final object target pose."""
class _DualArmHelpers:
    """Shared trajectory helpers for dual-arm coordinated actions.

    The methods rely on attributes supplied by the class that uses them
    (``self.robot``, ``self.cfg``, ``self.motion_generator`` and
    ``self.device``) and on the attributes cached by
    :meth:`_init_dual_arm_parts`.
    """

    def _init_dual_arm_parts(
        self,
        *,
        first_arm_control_part: str,
        second_arm_control_part: str,
        first_hand_control_part: str,
        second_hand_control_part: str,
    ) -> None:
        """Cache joint ids, DoF counts and column maps for both arms and hands.

        Args:
            first_arm_control_part: Control part name of the first arm.
            second_arm_control_part: Control part name of the second arm.
            first_hand_control_part: Control part name of the first hand.
            second_hand_control_part: Control part name of the second hand.
        """
        self.builder = TrajectoryBuilder(self.motion_generator)
        self.n_envs = self.robot.get_qpos().shape[0]
        self.robot_dof = self.robot.dof
        self.dual_arm_joint_ids = self.robot.get_joint_ids(name=self.cfg.control_part)
        self.first_arm_joint_ids = self.robot.get_joint_ids(name=first_arm_control_part)
        self.second_arm_joint_ids = self.robot.get_joint_ids(
            name=second_arm_control_part
        )
        self.first_hand_joint_ids = self.robot.get_joint_ids(
            name=first_hand_control_part
        )
        self.second_hand_joint_ids = self.robot.get_joint_ids(
            name=second_hand_control_part
        )
        self.first_arm_dof = len(self.first_arm_joint_ids)
        self.second_arm_dof = len(self.second_arm_joint_ids)
        self.dual_arm_dof = len(self.dual_arm_joint_ids)
        self.first_hand_dof = len(self.first_hand_joint_ids)
        self.second_hand_dof = len(self.second_hand_joint_ids)
        # Robot joint id -> column inside a dual-arm trajectory tensor.
        self._dual_id_to_col = {
            joint_id: col for col, joint_id in enumerate(self.dual_arm_joint_ids)
        }
        self._first_arm_cols = self._lookup_joint_columns(
            self.first_arm_joint_ids,
            self._dual_id_to_col,
            first_arm_control_part,
        )
        self._second_arm_cols = self._lookup_joint_columns(
            self.second_arm_joint_ids,
            self._dual_id_to_col,
            second_arm_control_part,
        )

    @staticmethod
    def _lookup_joint_columns(
        joint_ids: list[int],
        joint_id_to_col: dict[int, int],
        control_part: str,
    ) -> list[int]:
        """Map global joint ids into local trajectory columns.

        Args:
            joint_ids: Joint ids of one control part.
            joint_id_to_col: Mapping from joint id to trajectory column.
            control_part: Name of the control part, used in the error message.

        Returns:
            The column of every entry of ``joint_ids``, in the same order.

        Reports a ``ValueError`` through ``logger.log_error`` when a joint id
        is absent from ``joint_id_to_col``.
        """
        missing = [
            joint_id for joint_id in joint_ids if joint_id not in joint_id_to_col
        ]
        if missing:
            logger.log_error(
                f"Joints {missing} from '{control_part}' are not included in "
                "the configured dual-arm control part.",
                ValueError,
            )
        return [joint_id_to_col[joint_id] for joint_id in joint_ids]

    def _fail(self, state: WorldState) -> ActionResult:
        """Build a failed result with an empty trajectory and unchanged state."""
        return ActionResult(
            success=False,
            trajectory=torch.empty(
                (self.n_envs, 0, self.robot_dof),
                dtype=torch.float32,
                device=self.device,
            ),
            next_state=state,
        )

    def _expand_qpos(self, qpos: torch.Tensor, dof: int, name: str) -> torch.Tensor:
        """Resolve qpos to batched shape ``(n_envs, dof)``.

        Args:
            qpos: Tensor of shape ``(dof,)`` or ``(n_envs, dof)``.
            dof: Expected number of joints.
            name: Label used in the error message.

        Returns:
            A float32 tensor on ``self.device`` of shape ``(n_envs, dof)``.
        """
        qpos = qpos.to(device=self.device, dtype=torch.float32)
        if qpos.shape == (dof,):
            return qpos.unsqueeze(0).repeat(self.n_envs, 1)
        if qpos.shape == (self.n_envs, dof):
            return qpos
        logger.log_error(
            f"{name} must have shape ({dof},) or "
            f"({self.n_envs}, {dof}), but got {qpos.shape}",
            ValueError,
        )
        # Safety net in case the logger call above returns instead of raising.
        raise AssertionError("unreachable")

    def _resolve_pose(self, pose: torch.Tensor, name: str) -> torch.Tensor:
        """Resolve a pose tensor into batched shape ``(n_envs, 4, 4)``.

        Args:
            pose: Tensor of shape ``(4, 4)`` or ``(n_envs, 4, 4)``.
            name: Label used in the error message.

        Returns:
            A float32 tensor on ``self.device`` of shape ``(n_envs, 4, 4)``.
        """
        pose = pose.to(device=self.device, dtype=torch.float32)
        if pose.shape == (4, 4):
            pose = pose.unsqueeze(0).repeat(self.n_envs, 1, 1)
        if pose.shape != (self.n_envs, 4, 4):
            logger.log_error(
                f"{name} must have shape (4, 4) or "
                f"({self.n_envs}, 4, 4), but got {pose.shape}",
                ValueError,
            )
        return pose

    def _resolve_dual_arm_start(
        self,
        state: WorldState,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Resolve full-robot state into the two arm qpos tensors.

        Returns:
            ``(first_arm_qpos, second_arm_qpos)`` sliced out of
            ``state.last_qpos``.
        """
        dual_start = state.last_qpos[:, self.dual_arm_joint_ids].to(
            device=self.device, dtype=torch.float32
        )
        return (
            dual_start[:, self._first_arm_cols],
            dual_start[:, self._second_arm_cols],
        )

    def _plan_named_arm_trajectory(
        self,
        control_part: str,
        start_qpos: torch.Tensor,
        target_poses: torch.Tensor,
        n_waypoints: int,
    ) -> tuple[bool, torch.Tensor]:
        """Plan a batched arm trajectory for a named control part.

        Every target pose is solved with IK, each solve seeded with the
        previous solution, and the resulting joint keyframes (preceded by
        ``start_qpos``) are densified into ``n_waypoints`` samples.

        Args:
            control_part: Control part passed to ``robot.compute_ik``.
            start_qpos: Arm joint positions the motion starts from.
            target_poses: Target poses indexed as ``target_poses[:, i]``.
            n_waypoints: Number of samples of the returned trajectory.

        Returns:
            ``(True, trajectory)`` on success. When an IK solve fails for any
            environment, ``(False, keyframes)`` where ``keyframes`` is the
            partially filled keyframe buffer.
        """
        n_state = target_poses.shape[1]
        arm_dof = start_qpos.shape[-1]
        trajectory = torch.zeros(
            (self.n_envs, n_state, arm_dof),
            dtype=torch.float32,
            device=self.device,
        )
        qpos_seed = start_qpos
        for i in range(n_state):
            is_success, qpos = self.robot.compute_ik(
                pose=target_poses[:, i],
                name=control_part,
                joint_seed=qpos_seed,
            )
            if not self.builder.all_envs_success(is_success):
                logger.log_warning(
                    f"Failed to compute IK for {control_part} target state {i}."
                )
                return False, trajectory
            trajectory[:, i] = qpos
            qpos_seed = qpos

        keyframes = torch.cat([start_qpos.unsqueeze(1), trajectory], dim=1)
        if n_state == 1:
            # A single target is delegated to the builder's joint planner.
            return True, self.builder.plan_joint_traj(
                keyframes[:, 0],
                keyframes[:, -1],
                n_waypoints,
            )
        return True, self._interpolate_keyframe_qpos(keyframes, n_waypoints)

    def _compose_dual_arm_trajectory(
        self,
        first_arm_traj: torch.Tensor,
        second_arm_traj: torch.Tensor,
    ) -> torch.Tensor:
        """Compose first and second arm trajectories in dual-arm joint order.

        Returns:
            Tensor of shape ``(n_envs, n_waypoints, dual_arm_dof)``.
        """
        n_waypoints = first_arm_traj.shape[1]
        dual_arm_traj = torch.zeros(
            (self.n_envs, n_waypoints, self.dual_arm_dof),
            dtype=torch.float32,
            device=self.device,
        )
        dual_arm_traj[:, :, self._first_arm_cols] = first_arm_traj
        dual_arm_traj[:, :, self._second_arm_cols] = second_arm_traj
        return dual_arm_traj

    def _assemble_phase(
        self,
        state: WorldState,
        first_arm_traj: torch.Tensor,
        second_arm_traj: torch.Tensor,
        first_hand_traj: torch.Tensor,
        second_hand_traj: torch.Tensor,
    ) -> torch.Tensor:
        """Embed dual-arm and hand trajectories into full robot DoF order.

        Joints that belong to neither arm nor hand keep their value from
        ``state.last_qpos`` for every waypoint.

        Returns:
            Tensor of shape ``(n_envs, n_waypoints, robot_dof)``.
        """
        n_waypoints = first_arm_traj.shape[1]
        full = torch.empty(
            (self.n_envs, n_waypoints, self.robot_dof),
            dtype=torch.float32,
            device=self.device,
        )
        # Start from the current robot state, then overwrite controlled joints.
        full[:, :, :] = state.last_qpos.to(self.device).unsqueeze(1)
        full[:, :, self.dual_arm_joint_ids] = self._compose_dual_arm_trajectory(
            first_arm_traj, second_arm_traj
        )
        full[:, :, self.first_hand_joint_ids] = first_hand_traj
        full[:, :, self.second_hand_joint_ids] = second_hand_traj
        return full

    @staticmethod
    def _repeat_qpos(qpos: torch.Tensor, n_waypoints: int) -> torch.Tensor:
        """Repeat batched qpos across waypoints.

        Returns:
            Tensor of shape ``(batch, n_waypoints, dof)``.
        """
        return qpos.unsqueeze(1).repeat(1, n_waypoints, 1)

    def _interpolate_qpos(
        self,
        start_qpos: torch.Tensor,
        end_qpos: torch.Tensor,
        n_waypoints: int,
    ) -> torch.Tensor:
        """Interpolate batched qpos between two states.

        Returns:
            ``n_waypoints`` linearly spaced samples including both endpoints.
        """
        weights = torch.linspace(
            0.0,
            1.0,
            steps=n_waypoints,
            device=self.device,
            dtype=start_qpos.dtype,
        )
        return torch.lerp(
            start_qpos.unsqueeze(1),
            end_qpos.unsqueeze(1),
            weights[None, :, None],
        )

    def _interpolate_keyframe_qpos(
        self, keyframe_qpos: torch.Tensor, n_waypoints: int
    ) -> torch.Tensor:
        """Interpolate a sequence of qpos keyframes into ``n_waypoints`` samples.

        The keyframes are spread evenly over the waypoint range; the actual
        interpolation is done by :meth:`_interpolate_qpos_keyframes`.
        """
        n_keyframes = keyframe_qpos.shape[1]
        keyframe_indices = (
            torch.linspace(
                0,
                n_waypoints - 1,
                steps=n_keyframes,
                device=self.device,
            )
            .round()
            .to(dtype=torch.long)
        )
        return self._interpolate_qpos_keyframes(
            keyframe_qpos, keyframe_indices, n_waypoints
        )

    def _interpolate_qpos_keyframes(
        self,
        keyframe_qpos: torch.Tensor,
        keyframe_indices: torch.Tensor,
        n_waypoints: int,
    ) -> torch.Tensor:
        """Interpolate qpos keyframes using shared waypoint indices.

        Args:
            keyframe_qpos: Joint keyframes indexed as ``keyframe_qpos[:, k]``.
            keyframe_indices: Non-decreasing waypoint index of every keyframe.
            n_waypoints: Number of samples of the returned trajectory.

        Returns:
            Float32 tensor of shape ``(n_envs, n_waypoints, dof)`` that is
            piecewise linear between consecutive keyframes.
        """
        trajectory = torch.zeros(
            (self.n_envs, n_waypoints, keyframe_qpos.shape[-1]),
            dtype=torch.float32,
            device=self.device,
        )
        indices = [int(index) for index in keyframe_indices.tolist()]
        for segment_idx in range(len(indices) - 1):
            start_idx = indices[segment_idx]
            end_idx = indices[segment_idx + 1]
            n_segment = end_idx - start_idx + 1
            weights = torch.linspace(
                0.0,
                1.0,
                steps=n_segment,
                dtype=keyframe_qpos.dtype,
                device=self.device,
            )
            segment = torch.lerp(
                keyframe_qpos[:, segment_idx : segment_idx + 1],
                keyframe_qpos[:, segment_idx + 1 : segment_idx + 2],
                weights[None, :, None],
            )
            # The first sample of every later segment sits on the waypoint the
            # previous segment already wrote. Skipping it changes nothing when
            # indices are distinct, and keeps the start state from being
            # overwritten when rounding maps two keyframes to the same index.
            skip = 0 if segment_idx == 0 else 1
            trajectory[:, start_idx + skip : end_idx + 1] = segment[:, skip:]
        return trajectory

    def _interpolate_object_pose(
        self,
        start_pose: torch.Tensor,
        end_pose: torch.Tensor,
        n_waypoints: int,
        *,
        include_orientation: bool,
    ) -> torch.Tensor:
        """Interpolate object translation and optionally orientation.

        Args:
            start_pose: Batched start poses, shape ``(n_envs, 4, 4)``.
            end_pose: Batched end poses, shape ``(n_envs, 4, 4)``.
            n_waypoints: Number of poses to generate.
            include_orientation: When ``False`` every pose keeps the rotation
                of ``start_pose``; when ``True`` the rotation is blended too.

        Returns:
            Tensor of shape ``(n_envs, n_waypoints, 4, 4)``.
        """
        weights = torch.linspace(
            0.0,
            1.0,
            steps=n_waypoints,
            device=self.device,
            dtype=start_pose.dtype,
        )
        poses = start_pose.unsqueeze(1).repeat(1, n_waypoints, 1, 1)
        poses[:, :, :3, 3] = torch.lerp(
            start_pose[:, None, :3, 3],
            end_pose[:, None, :3, 3],
            weights[None, :, None],
        )
        if not include_orientation:
            return poses

        start_quat = quat_from_matrix(start_pose[:, :3, :3])
        end_quat = quat_from_matrix(end_pose[:, :3, :3])
        # Flip the end quaternion when needed so the blend takes the short arc.
        quat_dot = torch.sum(start_quat * end_quat, dim=-1, keepdim=True)
        end_quat = torch.where(quat_dot < 0.0, -end_quat, end_quat)
        # Normalized linear interpolation of the quaternions.
        quat = torch.lerp(
            start_quat.unsqueeze(1),
            end_quat.unsqueeze(1),
            weights[None, :, None],
        )
        quat = quat / torch.linalg.norm(quat, dim=-1, keepdim=True).clamp_min(1e-8)
        poses[:, :, :3, :3] = matrix_from_quat(quat.reshape(-1, 4)).reshape(
            self.n_envs, n_waypoints, 3, 3
        )
        return poses
class CoordinatedPickment(AtomicAction):
    """Pick and move a single object pinched by two hands.

    The planned trajectory consists of five consecutive phases: approach
    (pre-grasp then grasp pose for each arm), close (both hands), lift, move
    (to the object target pose) and hold. The trajectory helpers are borrowed
    from ``_DualArmHelpers`` as class attributes.
    """

    TargetType: ClassVar[type] = CoordinatedPickmentTarget

    _assemble_phase = _DualArmHelpers._assemble_phase
    _compose_dual_arm_trajectory = _DualArmHelpers._compose_dual_arm_trajectory
    _expand_qpos = _DualArmHelpers._expand_qpos
    _fail = _DualArmHelpers._fail
    _init_dual_arm_parts = _DualArmHelpers._init_dual_arm_parts
    _interpolate_keyframe_qpos = _DualArmHelpers._interpolate_keyframe_qpos
    _interpolate_object_pose = _DualArmHelpers._interpolate_object_pose
    _interpolate_qpos = _DualArmHelpers._interpolate_qpos
    _interpolate_qpos_keyframes = _DualArmHelpers._interpolate_qpos_keyframes
    _lookup_joint_columns = staticmethod(_DualArmHelpers._lookup_joint_columns)
    _plan_named_arm_trajectory = _DualArmHelpers._plan_named_arm_trajectory
    _repeat_qpos = staticmethod(_DualArmHelpers._repeat_qpos)
    _resolve_dual_arm_start = _DualArmHelpers._resolve_dual_arm_start
    _resolve_pose = _DualArmHelpers._resolve_pose

    def __init__(
        self,
        motion_generator,
        cfg: CoordinatedPickmentCfg | None = None,
    ) -> None:
        """Create the action and resolve all joint groups and hand states.

        Args:
            motion_generator: Motion generator forwarded to ``AtomicAction``.
            cfg: Action configuration; a default ``CoordinatedPickmentCfg`` is
                used when omitted. All four hand qpos fields must be set.
        """
        super().__init__(motion_generator, cfg or CoordinatedPickmentCfg())
        self._init_dual_arm_parts(
            first_arm_control_part=self.cfg.left_arm_control_part,
            second_arm_control_part=self.cfg.right_arm_control_part,
            first_hand_control_part=self.cfg.left_hand_control_part,
            second_hand_control_part=self.cfg.right_hand_control_part,
        )
        # Left/right aliases of the generic first/second attributes.
        self.left_arm_joint_ids = self.first_arm_joint_ids
        self.right_arm_joint_ids = self.second_arm_joint_ids
        self.left_hand_joint_ids = self.first_hand_joint_ids
        self.right_hand_joint_ids = self.second_hand_joint_ids
        self.left_arm_dof = self.first_arm_dof
        self.right_arm_dof = self.second_arm_dof
        self.left_hand_dof = self.first_hand_dof
        self.right_hand_dof = self.second_hand_dof

        self._validate_hand_qpos_cfg()
        self.left_hand_open_qpos = self._expand_qpos(
            self.cfg.left_hand_open_qpos, self.left_hand_dof, "left_hand_open_qpos"
        )
        self.left_hand_close_qpos = self._expand_qpos(
            self.cfg.left_hand_close_qpos, self.left_hand_dof, "left_hand_close_qpos"
        )
        self.right_hand_open_qpos = self._expand_qpos(
            self.cfg.right_hand_open_qpos, self.right_hand_dof, "right_hand_open_qpos"
        )
        self.right_hand_close_qpos = self._expand_qpos(
            self.cfg.right_hand_close_qpos,
            self.right_hand_dof,
            "right_hand_close_qpos",
        )

    def _validate_hand_qpos_cfg(self) -> None:
        """Ensure all hand state tensors are provided.

        Reports a ``ValueError`` through ``logger.log_error`` for the first
        hand qpos field that is still ``None``, matching the error type used
        by the other validations of this action.
        """
        for name in (
            "left_hand_open_qpos",
            "left_hand_close_qpos",
            "right_hand_open_qpos",
            "right_hand_close_qpos",
        ):
            if getattr(self.cfg, name) is None:
                logger.log_error(
                    f"{name} must be specified in CoordinatedPickmentCfg",
                    ValueError,
                )

    def _resolve_object_initial_pose(
        self, target: CoordinatedPickmentTarget
    ) -> torch.Tensor:
        """Resolve the current pose of the object being grasped.

        ``target.object_initial_pose`` wins when given; otherwise the pose is
        read from ``target.object_semantics.entity``. A ``ValueError`` is
        reported when neither is available.
        """
        if target.object_initial_pose is not None:
            return self._resolve_pose(target.object_initial_pose, "object_initial_pose")
        if target.object_semantics.entity is None:
            logger.log_error(
                "CoordinatedPickmentTarget requires object_initial_pose when "
                "object_semantics.entity is not provided.",
                ValueError,
            )
        return self._resolve_pose(
            target.object_semantics.entity.get_local_pose(to_matrix=True),
            "object_initial_pose",
        )

    def _resolve_target(
        self,
        target: CoordinatedPickmentTarget,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        CoordinatedHeldObjectState,
    ]:
        """Resolve an object-centric pickment target into left/right TCP poses.

        Returns:
            ``(object_initial_pose, object_target_pose, left_grasp_xpos,
            right_grasp_xpos, left_target_xpos, right_target_xpos, held_state)``
            where every pose is batched to ``(n_envs, 4, 4)`` and
            ``held_state`` describes the grasp at the initial object pose.
        """
        object_initial_pose = self._resolve_object_initial_pose(target)
        object_target_pose = self._resolve_pose(
            target.object_target_pose, "object_target_pose"
        )
        left_object_to_eef = self._resolve_pose(
            target.left_object_to_eef, "left_object_to_eef"
        )
        right_object_to_eef = self._resolve_pose(
            target.right_object_to_eef, "right_object_to_eef"
        )

        # End-effector pose = object pose composed with object->EEF transform.
        left_grasp_xpos = torch.bmm(object_initial_pose, left_object_to_eef)
        right_grasp_xpos = torch.bmm(object_initial_pose, right_object_to_eef)
        left_target_xpos = torch.bmm(object_target_pose, left_object_to_eef)
        right_target_xpos = torch.bmm(object_target_pose, right_object_to_eef)
        held_state = CoordinatedHeldObjectState(
            semantics=target.object_semantics,
            left_object_to_eef=left_object_to_eef,
            right_object_to_eef=right_object_to_eef,
            left_grasp_xpos=left_grasp_xpos,
            right_grasp_xpos=right_grasp_xpos,
        )
        return (
            object_initial_pose,
            object_target_pose,
            left_grasp_xpos,
            right_grasp_xpos,
            left_target_xpos,
            right_target_xpos,
            held_state,
        )

    def _compute_segment_lengths(self) -> dict[str, int]:
        """Compute waypoint counts for coordinated pickment phases.

        The close phase gets at least two waypoints and the hold phase at
        least zero; what remains of ``sample_interval`` is split in thirds
        between approach, lift and move, with the remainder going to move.
        A ``ValueError`` is reported when any of those three phases would
        have fewer than two waypoints.
        """
        n_close = max(2, self.cfg.hand_interp_steps)
        n_hold = max(0, self.cfg.hold_steps)
        n_motion = self.cfg.sample_interval - n_close - n_hold
        n_approach = n_motion // 3
        n_lift = n_motion // 3
        n_move = n_motion - n_approach - n_lift
        if min(n_approach, n_lift, n_move) < 2:
            logger.log_error(
                "Not enough waypoints for coordinated pickment. Please increase "
                "sample_interval or decrease hand_interp_steps/hold_steps.",
                ValueError,
            )
        return {
            "approach": n_approach,
            "close": n_close,
            "lift": n_lift,
            "move": n_move,
            "hold": n_hold,
        }

    def get_segment_lengths(self) -> dict[str, int]:
        """Return waypoint counts for the coordinated pickment phase sequence."""
        return self._compute_segment_lengths()

    def _compute_pre_grasp_xpos(self, grasp_xpos: torch.Tensor) -> torch.Tensor:
        """Compute pre-grasp poses by backing away along each TCP z axis."""
        # Third rotation column of the grasp pose, i.e. the TCP z axis.
        grasp_z = grasp_xpos[:, :3, 2]
        # NOTE(review): the offset passed here is expressed in the frame of
        # ``grasp_xpos``'s parent, while the helper is named "local" -- confirm
        # that ``apply_local_offset`` adds the translation in that frame.
        return self.builder.apply_local_offset(
            grasp_xpos, -grasp_z * self.cfg.pre_grasp_distance
        )

    def _select_motion_keyframe_indices(self, n_waypoints: int) -> torch.Tensor:
        """Select sparse object motion keyframes for IK, including endpoints.

        The number of keyframes is ``cfg.object_motion_keyframes`` clamped to
        the range ``[2, n_waypoints]``.
        """
        n_keyframes = min(max(2, int(self.cfg.object_motion_keyframes)), n_waypoints)
        return (
            torch.linspace(
                0,
                n_waypoints - 1,
                steps=n_keyframes,
                device=self.device,
            )
            .round()
            .to(dtype=torch.long)
        )

    def _plan_synchronized_object_motion(
        self,
        left_start_qpos: torch.Tensor,
        right_start_qpos: torch.Tensor,
        object_pose_traj: torch.Tensor,
        left_object_to_eef: torch.Tensor,
        right_object_to_eef: torch.Tensor,
    ) -> tuple[bool, torch.Tensor, torch.Tensor]:
        """Plan both arms from the same sparse object-pose trajectory.

        IK is solved for both arms at a handful of object-pose keyframes, each
        solve seeded with the previous keyframe's solution; the waypoints in
        between are filled by joint-space interpolation on shared indices so
        both arms stay in step.

        Args:
            left_start_qpos: Left arm IK seed for the first keyframe.
            right_start_qpos: Right arm IK seed for the first keyframe.
            object_pose_traj: Object poses indexed as ``object_pose_traj[:, i]``.
            left_object_to_eef: Object-to-left-EEF transform.
            right_object_to_eef: Object-to-right-EEF transform.

        Returns:
            ``(True, left_traj, right_traj)`` with one sample per object pose,
            or ``(False, left_keyframes, right_keyframes)`` with the partially
            filled keyframe buffers when any IK solve fails.
        """
        n_waypoints = object_pose_traj.shape[1]
        keyframe_indices = self._select_motion_keyframe_indices(n_waypoints)
        left_traj = torch.zeros(
            (self.n_envs, len(keyframe_indices), left_start_qpos.shape[-1]),
            dtype=torch.float32,
            device=self.device,
        )
        right_traj = torch.zeros(
            (self.n_envs, len(keyframe_indices), right_start_qpos.shape[-1]),
            dtype=torch.float32,
            device=self.device,
        )
        left_qpos_seed = left_start_qpos
        right_qpos_seed = right_start_qpos
        for keyframe_col, waypoint_idx in enumerate(keyframe_indices.tolist()):
            object_pose = object_pose_traj[:, waypoint_idx]

            left_xpos = torch.bmm(object_pose, left_object_to_eef)
            left_success, left_qpos = self.robot.compute_ik(
                pose=left_xpos,
                name=self.cfg.left_arm_control_part,
                joint_seed=left_qpos_seed,
            )
            # Bail out before solving the right arm: the result is unused
            # once the left arm has failed.
            if not self.builder.all_envs_success(left_success):
                logger.log_warning(
                    f"Failed to compute IK for {self.cfg.left_arm_control_part} "
                    f"object waypoint {waypoint_idx}."
                )
                return False, left_traj, right_traj

            right_xpos = torch.bmm(object_pose, right_object_to_eef)
            right_success, right_qpos = self.robot.compute_ik(
                pose=right_xpos,
                name=self.cfg.right_arm_control_part,
                joint_seed=right_qpos_seed,
            )
            if not self.builder.all_envs_success(right_success):
                logger.log_warning(
                    f"Failed to compute IK for {self.cfg.right_arm_control_part} "
                    f"object waypoint {waypoint_idx}."
                )
                return False, left_traj, right_traj

            left_traj[:, keyframe_col] = left_qpos
            right_traj[:, keyframe_col] = right_qpos
            left_qpos_seed = left_qpos
            right_qpos_seed = right_qpos

        return (
            True,
            self._interpolate_qpos_keyframes(left_traj, keyframe_indices, n_waypoints),
            self._interpolate_qpos_keyframes(right_traj, keyframe_indices, n_waypoints),
        )

    def execute(
        self, target: CoordinatedPickmentTarget, state: WorldState
    ) -> ActionResult:
        """Plan the full coordinated pick-and-move trajectory.

        Args:
            target: Object-centric target (object poses and grasp transforms).
            state: World state whose ``last_qpos`` the motion starts from.

        Returns:
            On success, an ``ActionResult`` whose trajectory concatenates the
            approach, close, lift, move and hold phases and whose next state
            has ``held_object=None`` and a ``coordinated_held_object`` carrying
            the end-effector poses at the object target pose. When any planning
            step fails, the result of ``self._fail(state)``.
        """
        (
            object_initial_pose,
            object_target_pose,
            left_grasp_xpos,
            right_grasp_xpos,
            left_target_xpos,
            right_target_xpos,
            held_state,
        ) = self._resolve_target(target)
        left_start_qpos, right_start_qpos = self._resolve_dual_arm_start(state)
        segments = self._compute_segment_lengths()

        # --- Approach: start -> pre-grasp -> grasp, hands open ---------------
        left_pre_grasp_xpos = self._compute_pre_grasp_xpos(left_grasp_xpos)
        right_pre_grasp_xpos = self._compute_pre_grasp_xpos(right_grasp_xpos)
        left_approach_targets = torch.stack(
            [left_pre_grasp_xpos, left_grasp_xpos], dim=1
        )
        right_approach_targets = torch.stack(
            [right_pre_grasp_xpos, right_grasp_xpos], dim=1
        )
        ok, left_approach_traj = self._plan_named_arm_trajectory(
            self.cfg.left_arm_control_part,
            left_start_qpos,
            left_approach_targets,
            segments["approach"],
        )
        if not ok:
            return self._fail(state)
        ok, right_approach_traj = self._plan_named_arm_trajectory(
            self.cfg.right_arm_control_part,
            right_start_qpos,
            right_approach_targets,
            segments["approach"],
        )
        if not ok:
            return self._fail(state)

        left_grasp_qpos = left_approach_traj[:, -1]
        right_grasp_qpos = right_approach_traj[:, -1]
        approach_trajectory = self._assemble_phase(
            state,
            left_approach_traj,
            right_approach_traj,
            self._repeat_qpos(self.left_hand_open_qpos, segments["approach"]),
            self._repeat_qpos(self.right_hand_open_qpos, segments["approach"]),
        )

        # --- Close: arms hold the grasp pose while both hands close ----------
        close_trajectory = self._assemble_phase(
            state,
            self._repeat_qpos(left_grasp_qpos, segments["close"]),
            self._repeat_qpos(right_grasp_qpos, segments["close"]),
            self._interpolate_qpos(
                self.left_hand_open_qpos,
                self.left_hand_close_qpos,
                segments["close"],
            ),
            self._interpolate_qpos(
                self.right_hand_open_qpos,
                self.right_hand_close_qpos,
                segments["close"],
            ),
        )

        # --- Lift: translate the object, orientation unchanged ---------------
        # NOTE(review): the config documents ``lift_height`` as a world-Z
        # distance but the offset goes through ``apply_local_offset`` --
        # confirm the helper's frame convention.
        lift_object_pose = self.builder.apply_local_offset(
            object_initial_pose,
            torch.tensor([0.0, 0.0, self.cfg.lift_height], device=self.device),
        )
        lift_object_traj = self._interpolate_object_pose(
            object_initial_pose,
            lift_object_pose,
            segments["lift"],
            include_orientation=False,
        )
        ok, left_lift_traj, right_lift_traj = self._plan_synchronized_object_motion(
            left_grasp_qpos,
            right_grasp_qpos,
            lift_object_traj,
            held_state.left_object_to_eef,
            held_state.right_object_to_eef,
        )
        if not ok:
            return self._fail(state)

        left_lift_qpos = left_lift_traj[:, -1]
        right_lift_qpos = right_lift_traj[:, -1]
        lift_trajectory = self._assemble_phase(
            state,
            left_lift_traj,
            right_lift_traj,
            self._repeat_qpos(self.left_hand_close_qpos, segments["lift"]),
            self._repeat_qpos(self.right_hand_close_qpos, segments["lift"]),
        )

        # --- Move: lifted pose -> object target pose, orientation blended ----
        move_object_traj = self._interpolate_object_pose(
            lift_object_pose,
            object_target_pose,
            segments["move"],
            include_orientation=True,
        )
        ok, left_move_traj, right_move_traj = self._plan_synchronized_object_motion(
            left_lift_qpos,
            right_lift_qpos,
            move_object_traj,
            held_state.left_object_to_eef,
            held_state.right_object_to_eef,
        )
        if not ok:
            return self._fail(state)

        left_target_qpos = left_move_traj[:, -1]
        right_target_qpos = right_move_traj[:, -1]
        move_trajectory = self._assemble_phase(
            state,
            left_move_traj,
            right_move_traj,
            self._repeat_qpos(self.left_hand_close_qpos, segments["move"]),
            self._repeat_qpos(self.right_hand_close_qpos, segments["move"]),
        )

        # --- Hold: optional dwell at the final configuration -----------------
        hold_trajectory = torch.empty(
            (self.n_envs, 0, self.robot_dof), dtype=torch.float32, device=self.device
        )
        if segments["hold"] > 0:
            hold_trajectory = self._assemble_phase(
                state,
                self._repeat_qpos(left_target_qpos, segments["hold"]),
                self._repeat_qpos(right_target_qpos, segments["hold"]),
                self._repeat_qpos(self.left_hand_close_qpos, segments["hold"]),
                self._repeat_qpos(self.right_hand_close_qpos, segments["hold"]),
            )

        full = torch.cat(
            [
                approach_trajectory,
                close_trajectory,
                lift_trajectory,
                move_trajectory,
                hold_trajectory,
            ],
            dim=1,
        )
        # The state handed to the next action records the end-effector poses
        # at the object *target* pose, not at the original grasp location.
        coordinated_held_object = CoordinatedHeldObjectState(
            semantics=held_state.semantics,
            left_object_to_eef=held_state.left_object_to_eef,
            right_object_to_eef=held_state.right_object_to_eef,
            left_grasp_xpos=left_target_xpos,
            right_grasp_xpos=right_target_xpos,
        )
        return ActionResult(
            success=True,
            trajectory=full,
            next_state=WorldState(
                last_qpos=full[:, -1, :].clone(),
                held_object=None,
                coordinated_held_object=coordinated_held_object,
            ),
        )
# ============================================================================= @@ -908,6 +1693,7 @@ def execute(self, target: EndEffectorPoseTarget, state: WorldState) -> ActionRes next_state=WorldState( last_qpos=full[:, -1, :].clone(), held_object=state.held_object, + coordinated_held_object=state.coordinated_held_object, ), ) @@ -942,6 +1728,8 @@ def _fail(self, state: WorldState) -> ActionResult: __all__ = [ + "CoordinatedPickment", + "CoordinatedPickmentCfg", "MoveEndEffector", "MoveEndEffectorCfg", "MoveJoints", diff --git a/embodichain/lab/sim/atomic_actions/core.py b/embodichain/lab/sim/atomic_actions/core.py index 0f9f29094..4c937a0e8 100644 --- a/embodichain/lab/sim/atomic_actions/core.py +++ b/embodichain/lab/sim/atomic_actions/core.py @@ -119,12 +119,33 @@ class HeldObjectPoseTarget: """(4, 4) or (n_envs, 4, 4) target pose for the held object.""" +@dataclass(frozen=True) +class CoordinatedPickmentTarget: + """Object-centric target for picking and moving one object with two hands.""" + + object_target_pose: torch.Tensor + """Target pose for the shared object, shape (4, 4) or (n_envs, 4, 4).""" + + object_semantics: ObjectSemantics + """Semantic description of the shared object.""" + + left_object_to_eef: torch.Tensor + """Transform from object frame to left end-effector frame.""" + + right_object_to_eef: torch.Tensor + """Transform from object frame to right end-effector frame.""" + + object_initial_pose: torch.Tensor | None = None + """Optional initial object pose. 
Defaults to ``object_semantics.entity`` pose.""" + + Target = ( EndEffectorPoseTarget | JointPositionTarget | NamedJointPositionTarget | GraspTarget | HeldObjectPoseTarget + | CoordinatedPickmentTarget ) @@ -147,6 +168,26 @@ class HeldObjectState: """Batched end-effector pose used to grasp the object, shape [n_envs, 4, 4].""" +@dataclass +class CoordinatedHeldObjectState: + """State of a single object jointly held by two robot hands.""" + + semantics: ObjectSemantics + """Semantic object currently held by the two grippers.""" + + left_object_to_eef: torch.Tensor + """Transform from object frame to left end-effector frame, shape [n_envs, 4, 4].""" + + right_object_to_eef: torch.Tensor + """Transform from object frame to right end-effector frame, shape [n_envs, 4, 4].""" + + left_grasp_xpos: torch.Tensor + """Left end-effector grasp pose for the shared object, shape [n_envs, 4, 4].""" + + right_grasp_xpos: torch.Tensor + """Right end-effector grasp pose for the shared object, shape [n_envs, 4, 4].""" + + @dataclass class WorldState: """State the engine threads through a sequence of actions.""" @@ -157,6 +198,9 @@ class WorldState: held_object: HeldObjectState | None = None """Object currently held by the gripper, or None.""" + coordinated_held_object: CoordinatedHeldObjectState | None = None + """Object currently held by two grippers, or None.""" + @dataclass class ActionResult: @@ -232,6 +276,8 @@ def execute(self, target: Target, state: WorldState) -> ActionResult: "ActionCfg", "ActionResult", "AtomicAction", + "CoordinatedHeldObjectState", + "CoordinatedPickmentTarget", "GraspTarget", "HeldObjectState", "HeldObjectPoseTarget", diff --git a/scripts/tutorials/atomic_action/coordinated_pickment.py b/scripts/tutorials/atomic_action/coordinated_pickment.py new file mode 100644 index 000000000..e19de0885 --- /dev/null +++ b/scripts/tutorials/atomic_action/coordinated_pickment.py @@ -0,0 +1,873 @@ +# 
---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Demonstrate dual-arm coordinated pickment with selectable object meshes. + +The two UR5 arms pinch opposite sides of one object, lift it together, and move +the object to an object-centric target pose while both grippers stay closed. 
+""" + +from __future__ import annotations + +import argparse +import math +import os +import sys +import time +from dataclasses import dataclass +from pathlib import Path + +_REPO_ROOT = Path(__file__).resolve().parents[3] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +import numpy as np +import torch +from scipy.spatial.transform import Rotation as SciRotation + +from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.atomic_actions import ( + Affordance, + CoordinatedPickment, + CoordinatedPickmentCfg, + CoordinatedPickmentTarget, + ObjectSemantics, + WorldState, +) +from embodichain.lab.sim.cfg import ( + JointDrivePropertiesCfg, + LightCfg, + RenderCfg, + RigidBodyAttributesCfg, + RigidObjectCfg, + RobotCfg, + URDFCfg, +) +from embodichain.lab.sim.objects import RigidObject, Robot +from embodichain.lab.sim.planners import MotionGenCfg, MotionGenerator, ToppraPlannerCfg +from embodichain.lab.sim.shapes import MeshCfg +from embodichain.lab.sim.solvers import PytorchSolverCfg +from embodichain.utils import logger +from embodichain.utils.math import matrix_from_euler +from scripts.tutorials.atomic_action.tutorial_utils import ( + draw_axis_marker, + get_tutorial_window_size, + start_auto_play_recording, + stop_auto_play_recording, +) + +DEFAULT_MESH_FRAME_CORRECTION_EULER_DEG = (-90.0, 0.0, 0.0) + + +def transform_baseline_pose( + init_pos: tuple[float, float, float], + init_rot: tuple[float, float, float], + *, + z_offset: float = 0.0, + mesh_frame_correction_euler_deg: tuple[float, float, float] = ( + DEFAULT_MESH_FRAME_CORRECTION_EULER_DEG + ), + world_yaw_correction_deg: float = 0.0, +) -> tuple[tuple[float, float, float], tuple[float, float, float]]: + """Apply mesh-frame correction while preserving baseline world placement.""" + pos = np.asarray(init_pos, dtype=np.float64) + pos[2] += z_offset + rot = ( + 
SciRotation.from_euler("Z", world_yaw_correction_deg, degrees=True) + * SciRotation.from_euler("XYZ", init_rot, degrees=True) + * SciRotation.from_euler("XYZ", mesh_frame_correction_euler_deg, degrees=True) + ).as_euler("XYZ", degrees=True) + return tuple(float(value) for value in pos), tuple(float(value) for value in rot) + + +ARM_URDF_PATH = "UniversalRobots/UR5/UR5.urdf" +GRIPPER_URDF_PATH = "DH_PGI_140_80/DH_PGI_140_80.urdf" +PICKMENT_ASSET_ROOT = "CoordinatedPlacementAndPickment" +TABLE_MESH_PATH = f"{PICKMENT_ASSET_ROOT}/table.glb" +GRIPPER_TCP_Z = 0.121 +ROBOT_INIT_POS = (1.95, 0.0, 0.1) +ROBOT_INIT_ROT = (0.0, 0.0, -90.0) +LEFT_ARM_HOME = (0.0, 0.0, -1.57, -1.57, 1.57, 1.57) +RIGHT_ARM_HOME = (-1.57, -1.57, -1.57, -1.57, 0.0, 0.0) +TABLE_TOP_Z = 0.65 +BASELINE_TABLE_TOP_Z = 0.3621708124799265 +SCENE_Z_OFFSET = TABLE_TOP_Z - BASELINE_TABLE_TOP_Z +BASELINE_TABLE_INIT_POS = ( + 0.00014585733079742588, + 0.00023304896730074557, + -0.019599792839044783, +) +BASELINE_TABLE_INIT_ROT = ( + 0.0001074673904926984, + 0.00865572768366991, + -90.6562109309317, +) +TABLE_INIT_POS, TABLE_INIT_ROT = transform_baseline_pose( + BASELINE_TABLE_INIT_POS, + BASELINE_TABLE_INIT_ROT, + z_offset=SCENE_Z_OFFSET, +) +PICKMENT_RECORD_LOOK_AT = ( + (-0.25, 0.02, 2.5), + (0.0, 0.02, 0.75), + (0.0, 0.0, 1.0), +) + + +@dataclass(frozen=True) +class PickmentObjectPreset: + """Configuration for an object used by the coordinated pickment demo.""" + + label: str + mesh_path: str + init_xy: tuple[float, float] + init_rot: tuple[float, float, float] + table_clearance: float + body_scale: tuple[float, float, float] + grasp_end_margin_ratio: float + grasp_z_clearance: float + target_translation: tuple[float, float, float] + target_world_yaw_deg: float + hand_close_qpos: float + grasp_z_ratio: float | None = None + + +OBJECT_PRESETS = { + "pencil": PickmentObjectPreset( + label="pencil", + mesh_path=f"{PICKMENT_ASSET_ROOT}/pencil.glb", + init_xy=(-0.02, 0.02), + # Rotate the imported pencil from 
its default upright orientation to a tabletop pose. + init_rot=(90.0, 0.0, 0.0), + table_clearance=0.008, + body_scale=(2.0, 2.0, 2.0), + grasp_end_margin_ratio=0.12, + grasp_z_clearance=0.015, + target_translation=(-0.22, -0.04, 0.16), + target_world_yaw_deg=0.0, + hand_close_qpos=0.026, + ), + "pot": PickmentObjectPreset( + label="pot", + mesh_path=f"{PICKMENT_ASSET_ROOT}/pot.glb", + init_xy=(-0.02, 0.02), + init_rot=(-90.0, 90.0, 0.0), + table_clearance=0.008, + body_scale=(2.0, 2.0, 2.0), + grasp_end_margin_ratio=0.08, + grasp_z_clearance=0.01, + target_translation=(-0.12, -0.03, 0.12), + target_world_yaw_deg=0.0, + hand_close_qpos=0.026, + grasp_z_ratio=0.55, + ), +} +PICKMENT_SAMPLE_INTERVAL = 96 +PICKMENT_OBJECT_MOTION_KEYFRAMES = 6 +PICKMENT_PRE_GRASP_DISTANCE = 0.11 +PICKMENT_LIFT_HEIGHT = 0.10 +PICKMENT_HAND_INTERP_STEPS = 10 +PICKMENT_HOLD_STEPS = 4 +TRAJECTORY_SIM_STEPS = 4 + + +def parse_arguments() -> argparse.Namespace: + """Parse command-line arguments for the demo.""" + parser = argparse.ArgumentParser(description="Dual-arm coordinated pickment demo") + add_env_launcher_args_to_parser(parser) + parser.set_defaults(device="cuda", renderer="hybrid") + parser.add_argument( + "--diagnose_plan", + action="store_true", + help="Plan and print diagnostics without playing the trajectory.", + ) + parser.add_argument( + "--debug_state", + action="store_true", + help="Log hand targets and object poses during execution.", + ) + parser.add_argument( + "--auto_play", + action="store_true", + help="Run the viewer demo without waiting for keyboard input.", + ) + parser.add_argument( + "--headless_play", + action="store_true", + help="Execute planned trajectories without opening the viewer window.", + ) + parser.add_argument( + "--object", + choices=sorted(OBJECT_PRESETS), + default="pencil", + help="Object mesh to grasp in the coordinated pickment demo.", + ) + parser.add_argument( + "--no_vis_eef_axis", + action="store_true", + help="Do not draw the pickment 
target/grasp coordinate frames before planning.", + ) + return parser.parse_args() + + +def get_cached_data_path(data_path: str) -> str: + """Resolve an asset path from the local cache before importing data helpers.""" + if os.path.isabs(data_path): + return data_path + + data_root = Path( + os.environ.get( + "EMBODICHAIN_DATA_ROOT", + str(Path.home() / ".cache" / "embodichain_data"), + ) + ) + candidates = ( + data_root / data_path, + data_root / "extract" / data_path, + ) + for candidate in candidates: + if candidate.exists(): + return str(candidate) + + from embodichain.data import get_data_path + + return get_data_path(data_path) + + +def rotation_z(yaw: float) -> np.ndarray: + """Build a 3x3 yaw rotation matrix.""" + cos_yaw = math.cos(yaw) + sin_yaw = math.sin(yaw) + return np.array( + [ + [cos_yaw, -sin_yaw, 0.0], + [sin_yaw, cos_yaw, 0.0], + [0.0, 0.0, 1.0], + ], + dtype=np.float32, + ) + + +def make_transform(xyz: tuple[float, float, float], yaw: float) -> np.ndarray: + """Build a homogeneous transform from translation and yaw.""" + transform = np.eye(4, dtype=np.float32) + transform[:3, :3] = rotation_z(yaw) + transform[:3, 3] = np.asarray(xyz, dtype=np.float32) + return transform + + +def initialize_simulation(args: argparse.Namespace) -> SimulationManager: + """Create the simulation manager and a light.""" + width, height = get_tutorial_window_size(args) + sim = SimulationManager( + SimulationManagerCfg( + width=width, + height=height, + headless=True, + sim_device=args.device, + render_cfg=RenderCfg(renderer=args.renderer), + physics_dt=1.0 / 100.0, + arena_space=3.0, + ) + ) + sim.add_light( + cfg=LightCfg( + uid="main_light", + color=(0.6, 0.6, 0.6), + intensity=30.0, + init_pos=(0.0, -0.4, 3.0), + ) + ) + return sim + + +def create_dual_ur5_robot(sim: SimulationManager) -> Robot: + """Create a dual-UR5 robot with one PGI gripper on each arm.""" + arm_urdf_path = get_cached_data_path(ARM_URDF_PATH) + gripper_urdf_path = 
get_cached_data_path(GRIPPER_URDF_PATH) + tcp = [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, GRIPPER_TCP_Z], + [0.0, 0.0, 0.0, 1.0], + ] + cfg = RobotCfg( + uid="DualUR5CoordinatedPickment", + urdf_cfg=URDFCfg( + components=[ + { + "component_type": "left_arm", + "urdf_path": arm_urdf_path, + "transform": make_transform((-0.3, -1.45, 0.4), np.pi / 2), + }, + { + "component_type": "right_arm", + "urdf_path": arm_urdf_path, + "transform": make_transform((0.3, -1.45, 0.4), np.pi / 2), + }, + {"component_type": "left_hand", "urdf_path": gripper_urdf_path}, + {"component_type": "right_hand", "urdf_path": gripper_urdf_path}, + ], + fname="dual_ur5_coordinated_pickment", + name_case={"joint": "upper", "link": "lower"}, + ), + drive_pros=JointDrivePropertiesCfg( + stiffness={ + "LEFT_JOINT[0-9]": 1e4, + "RIGHT_JOINT[0-9]": 1e4, + "LEFT_GRIPPER_FINGER[1-2]_JOINT_1": 1e3, + "RIGHT_GRIPPER_FINGER[1-2]_JOINT_1": 1e3, + }, + damping={ + "LEFT_JOINT[0-9]": 1e3, + "RIGHT_JOINT[0-9]": 1e3, + "LEFT_GRIPPER_FINGER[1-2]_JOINT_1": 1e2, + "RIGHT_GRIPPER_FINGER[1-2]_JOINT_1": 1e2, + }, + max_effort={ + "LEFT_JOINT[0-9]": 1e5, + "RIGHT_JOINT[0-9]": 1e5, + "LEFT_GRIPPER_FINGER[1-2]_JOINT_1": 1e4, + "RIGHT_GRIPPER_FINGER[1-2]_JOINT_1": 1e4, + }, + drive_type="force", + ), + control_parts={ + "left_arm": ["LEFT_JOINT[0-9]"], + "right_arm": ["RIGHT_JOINT[0-9]"], + "dual_arm": ["LEFT_JOINT[0-9]", "RIGHT_JOINT[0-9]"], + "left_hand": ["LEFT_GRIPPER_FINGER1_JOINT_1"], + "right_hand": ["RIGHT_GRIPPER_FINGER1_JOINT_1"], + }, + solver_cfg={ + "left_arm": PytorchSolverCfg( + end_link_name="left_ee_link", + root_link_name="left_base_link", + tcp=tcp, + num_samples=30, + ), + "right_arm": PytorchSolverCfg( + end_link_name="right_ee_link", + root_link_name="right_base_link", + tcp=tcp, + num_samples=30, + ), + }, + init_pos=list(ROBOT_INIT_POS), + init_rot=list(ROBOT_INIT_ROT), + init_qpos=list(LEFT_ARM_HOME) + list(RIGHT_ARM_HOME) + [0.0, 0.0, 0.0, 0.0], + ) + return 
sim.add_robot(cfg=cfg) + + +def create_table(sim: SimulationManager) -> RigidObject: + """Create the table mesh used by the pickment scene.""" + return sim.add_rigid_object( + cfg=RigidObjectCfg( + uid="table", + shape=MeshCfg(fpath=get_cached_data_path(TABLE_MESH_PATH)), + attrs=RigidBodyAttributesCfg( + mass=10.0, + dynamic_friction=0.9, + static_friction=0.95, + restitution=0.01, + ), + body_type="kinematic", + init_pos=list(TABLE_INIT_POS), + init_rot=list(TABLE_INIT_ROT), + ) + ) + + +def create_pickment_object( + sim: SimulationManager, + preset: PickmentObjectPreset, +) -> RigidObject: + """Create the selected object mesh on the table.""" + obj = sim.add_rigid_object( + cfg=RigidObjectCfg( + uid=preset.label, + shape=MeshCfg( + fpath=get_cached_data_path(preset.mesh_path), compute_uv=False + ), + attrs=RigidBodyAttributesCfg( + mass=0.01, + dynamic_friction=0.97, + static_friction=0.99, + angular_damping=1.0, + linear_damping=0.5, + contact_offset=0.001, + rest_offset=0.0, + restitution=0.01, + min_position_iters=32, + min_velocity_iters=8, + max_depenetration_velocity=2.0, + ), + max_convex_hull_num=16, + init_pos=[preset.init_xy[0], preset.init_xy[1], TABLE_TOP_Z], + init_rot=list(preset.init_rot), + body_scale=preset.body_scale, + ) + ) + obj.cfg.init_pos = compute_tabletop_init_pos(obj, preset) + obj.reset() + return obj + + +def settle_object(sim: SimulationManager, obj: RigidObject, step: int = 5) -> None: + """Settle an object before planning.""" + if sim.device.type == "cuda": + sim.init_gpu_physics() + obj.reset() + if step > 0: + sim.update(step=step) + obj.clear_dynamics() + + +def create_object_semantics(obj: RigidObject, label: str) -> ObjectSemantics: + """Create minimal object semantics for manually specified grasps.""" + return ObjectSemantics( + label=label, + geometry={}, + affordance=Affordance(object_label=label), + entity=obj, + ) + + +def get_hand_open_close_qpos( + robot: Robot, + hand_control_part: str, + device: torch.device, + 
close_qpos: float, +) -> tuple[torch.Tensor, torch.Tensor]: + """Get open and close qpos for a PGI gripper control part.""" + limits = robot.get_qpos_limits(name=hand_control_part)[0].to( + device=device, dtype=torch.float32 + ) + hand_open = limits[:, 0] + hand_close = torch.clamp( + torch.full_like(limits[:, 1], close_qpos), + min=limits[:, 0], + max=limits[:, 1], + ) + return hand_open, hand_close + + +def get_local_vertices(obj: RigidObject) -> torch.Tensor: + """Get scaled local mesh vertices.""" + return obj.get_vertices(env_ids=[0], scale=True)[0] + + +def compute_local_bounds(vertices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Compute local mesh AABB from scaled vertices.""" + return vertices.min(dim=0).values, vertices.max(dim=0).values + + +def compute_tabletop_init_pos( + obj: RigidObject, + preset: PickmentObjectPreset, +) -> tuple[float, float, float]: + """Place an object so its rotated mesh bottom sits on the table.""" + vertices = get_local_vertices(obj) + rot = torch.as_tensor(preset.init_rot, dtype=torch.float32, device=vertices.device) + rot = rot.unsqueeze(0) * torch.pi / 180.0 + upright_rot = matrix_from_euler(rot, "XYZ")[0] + rotated_vertices = vertices @ upright_rot.T + bottom_z = rotated_vertices[:, 2].min().item() + z = TABLE_TOP_Z + preset.table_clearance - bottom_z + return (preset.init_xy[0], preset.init_xy[1], z) + + +def invert_pose(pose: torch.Tensor) -> torch.Tensor: + """Invert batched homogeneous transforms.""" + inv_pose = pose.clone() + rot_t = pose[:, :3, :3].transpose(1, 2) + inv_pose[:, :3, :3] = rot_t + inv_pose[:, :3, 3] = -torch.bmm(rot_t, pose[:, :3, 3:4]).squeeze(-1) + return inv_pose + + +def transform_points(pose: torch.Tensor, points: torch.Tensor) -> torch.Tensor: + """Transform local points by a homogeneous pose.""" + return points @ pose[:3, :3].transpose(0, 1) + pose[:3, 3] + + +def compute_world_bounds( + object_pose: torch.Tensor, + local_vertices: torch.Tensor, +) -> tuple[torch.Tensor, 
torch.Tensor]: + """Compute world AABB from transformed local mesh vertices.""" + world_vertices = transform_points(object_pose, local_vertices) + return world_vertices.min(dim=0).values, world_vertices.max(dim=0).values + + +def normalize_vector(vector: torch.Tensor, fallback: torch.Tensor) -> torch.Tensor: + """Normalize a vector with a deterministic fallback for degenerate cases.""" + norm = torch.linalg.norm(vector) + if norm < 1e-6: + return fallback.to(device=vector.device, dtype=vector.dtype) + return vector / norm + + +def rotate_pose_about_world_z(pose: torch.Tensor, yaw_deg: float) -> torch.Tensor: + """Rotate pose orientation about world Z while preserving translation.""" + yaw = math.radians(yaw_deg) + rot = torch.eye(3, dtype=pose.dtype, device=pose.device) + rot[0, 0] = math.cos(yaw) + rot[0, 1] = -math.sin(yaw) + rot[1, 0] = math.sin(yaw) + rot[1, 1] = math.cos(yaw) + rotated_pose = pose.clone() + rotated_pose[:3, :3] = rot @ pose[:3, :3] + return rotated_pose + + +def build_object_grasp_poses( + object_pose: torch.Tensor, + local_vertices: torch.Tensor, + preset: PickmentObjectPreset, + device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor]: + """Build left/right TCP poses that pinch opposite sides of the object.""" + local_min, local_max = compute_local_bounds(local_vertices) + extents = local_max - local_min + long_axis_idx = int(torch.argmax(extents).item()) + axis_local = torch.zeros(3, dtype=torch.float32, device=device) + axis_local[long_axis_idx] = 1.0 + long_axis = normalize_vector( + object_pose[:3, :3] @ axis_local, + torch.tensor([1.0, 0.0, 0.0], dtype=torch.float32, device=device), + ) + + local_center = 0.5 * (local_min + local_max) + margin = extents[long_axis_idx] * preset.grasp_end_margin_ratio + left_local = local_center.clone() + right_local = local_center.clone() + left_local[long_axis_idx] = local_min[long_axis_idx] + margin + right_local[long_axis_idx] = local_max[long_axis_idx] - margin + + world_min, world_max = 
compute_world_bounds(object_pose, local_vertices) + left_position = object_pose[:3, 3] + object_pose[:3, :3] @ left_local.to(device) + right_position = object_pose[:3, 3] + object_pose[:3, :3] @ right_local.to(device) + if preset.grasp_z_ratio is None: + grasp_z = world_max[2] + preset.grasp_z_clearance + else: + grasp_z = ( + world_min[2] + + (world_max[2] - world_min[2]) * preset.grasp_z_ratio + + preset.grasp_z_clearance + ) + left_position[2] = grasp_z + right_position[2] = grasp_z + + z_axis = torch.tensor([0.0, 0.0, -1.0], dtype=torch.float32, device=device) + x_axis = normalize_vector( + torch.cross(long_axis, z_axis, dim=0), + torch.tensor([1.0, 0.0, 0.0], dtype=torch.float32, device=device), + ) + y_axis = normalize_vector(torch.cross(z_axis, x_axis, dim=0), long_axis) + + left_pose = torch.eye(4, dtype=torch.float32, device=device) + left_pose[:3, 0] = x_axis + left_pose[:3, 1] = y_axis + left_pose[:3, 2] = z_axis + left_pose[:3, 3] = left_position + + right_pose = torch.eye(4, dtype=torch.float32, device=device) + right_pose[:3, 0] = -x_axis + right_pose[:3, 1] = -y_axis + right_pose[:3, 2] = z_axis + right_pose[:3, 3] = right_position + return left_pose, right_pose + + +def build_object_target_pose( + object_pose: torch.Tensor, + object_vertices: torch.Tensor, + preset: PickmentObjectPreset, + device: torch.device, +) -> torch.Tensor: + """Build the target pose for the whole object.""" + pose = rotate_pose_about_world_z( + object_pose.clone().to(device=device, dtype=torch.float32), + preset.target_world_yaw_deg, + ) + pose[:3, 3] += torch.tensor( + preset.target_translation, dtype=torch.float32, device=device + ) + bottom_z = compute_world_bounds(pose, object_vertices)[0][2] + pose[2, 3] += TABLE_TOP_Z + preset.table_clearance + 0.10 - bottom_z + return pose + + +def format_tensor(tensor: torch.Tensor) -> str: + """Format tensor values for compact logging.""" + rounded = (tensor.detach().cpu() * 10000.0).round() / 10000.0 + return str(rounded.tolist()) 
+ + +def log_action_plan( + robot: Robot, + action_name: str, + traj: torch.Tensor, + joint_ids: list[int], + segments: dict[str, int] | None = None, +) -> None: + """Log common action plan details.""" + joint_names = [robot.joint_names[joint_id] for joint_id in joint_ids] + logger.log_info(f"{action_name} joint ids: {joint_ids}") + logger.log_info(f"{action_name} joint names: {joint_names}") + logger.log_info(f"{action_name} trajectory shape: {tuple(traj.shape)}") + if segments is not None: + logger.log_info(f"{action_name} trajectory segments: {segments}") + + +def log_scene_targets( + object_label: str, + object_pose: torch.Tensor, + target_pose: torch.Tensor, + left_grasp_pose: torch.Tensor, + right_grasp_pose: torch.Tensor, +) -> None: + """Log compact object and grasp target positions.""" + logger.log_info( + "pickment scene: " + f"object={object_label}, " + f"object_origin={format_tensor(object_pose[:3, 3])}, " + f"target_origin={format_tensor(target_pose[:3, 3])}, " + f"left_grasp={format_tensor(left_grasp_pose[:3, 3])}, " + f"right_grasp={format_tensor(right_grasp_pose[:3, 3])}" + ) + + +def draw_pickment_target_axes( + sim: SimulationManager, + object_target_pose: torch.Tensor, + left_grasp_pose: torch.Tensor, + right_grasp_pose: torch.Tensor, +) -> None: + """Draw semantic axes for the target object pose and two grasp TCP poses.""" + draw_axis_marker( + sim, + "coordinated_pickment_object_target_axis", + object_target_pose, + axis_len=0.12, + axis_size=0.005, + ) + draw_axis_marker( + sim, + "coordinated_pickment_left_grasp_axis", + left_grasp_pose, + axis_len=0.07, + axis_size=0.0035, + ) + draw_axis_marker( + sim, + "coordinated_pickment_right_grasp_axis", + right_grasp_pose, + axis_len=0.07, + axis_size=0.0035, + ) + + +def log_execution_state( + robot: Robot, + obj: RigidObject, + step_idx: int, + total_steps: int, +) -> None: + """Log hand and object state during execution.""" + object_pose = obj.get_local_pose(to_matrix=True) + left_hand = 
robot.get_qpos(name="left_hand") + right_hand = robot.get_qpos(name="right_hand") + logger.log_info( + f"step={step_idx}/{total_steps - 1}, " + f"left_hand={format_tensor(left_hand[0])}, " + f"right_hand={format_tensor(right_hand[0])}, " + f"{obj.uid}_pos={format_tensor(object_pose[0, :3, 3])}" + ) + + +def execute_trajectory( + sim: SimulationManager, + robot: Robot, + traj: torch.Tensor, + obj: RigidObject, + debug_state: bool, +) -> None: + """Play a planned trajectory in simulation.""" + total_steps = traj.shape[1] + log_stride = max(1, total_steps // 10) + for i in range(total_steps): + robot.set_qpos(traj[:, i, :]) + sim.update(step=TRAJECTORY_SIM_STEPS) + if debug_state and (i % log_stride == 0 or i == total_steps - 1): + log_execution_state(robot, obj, i, total_steps) + time.sleep(1e-2) + + +def run_coordinated_pickment_demo( + args: argparse.Namespace, + sim: SimulationManager, + robot: Robot, +) -> None: + """Plan and optionally execute coordinated object pickment.""" + preset = OBJECT_PRESETS[args.object] + create_table(sim) + obj = create_pickment_object(sim, preset) + settle_object(sim, obj, step=0) + object_pose = obj.get_local_pose(to_matrix=True)[0].to( + device=sim.device, dtype=torch.float32 + ) + object_vertices = get_local_vertices(obj) + object_semantics = create_object_semantics(obj, preset.label) + motion_gen = MotionGenerator( + cfg=MotionGenCfg(planner_cfg=ToppraPlannerCfg(robot_uid=robot.uid)) + ) + + left_open, left_close = get_hand_open_close_qpos( + robot, "left_hand", sim.device, preset.hand_close_qpos + ) + right_open, right_close = get_hand_open_close_qpos( + robot, "right_hand", sim.device, preset.hand_close_qpos + ) + pickment_action = CoordinatedPickment( + motion_generator=motion_gen, + cfg=CoordinatedPickmentCfg( + control_part="dual_arm", + left_arm_control_part="left_arm", + right_arm_control_part="right_arm", + left_hand_control_part="left_hand", + right_hand_control_part="right_hand", + left_hand_open_qpos=left_open, + 
left_hand_close_qpos=left_close, + right_hand_open_qpos=right_open, + right_hand_close_qpos=right_close, + pre_grasp_distance=PICKMENT_PRE_GRASP_DISTANCE, + lift_height=PICKMENT_LIFT_HEIGHT, + sample_interval=PICKMENT_SAMPLE_INTERVAL, + hand_interp_steps=PICKMENT_HAND_INTERP_STEPS, + hold_steps=PICKMENT_HOLD_STEPS, + object_motion_keyframes=PICKMENT_OBJECT_MOTION_KEYFRAMES, + ), + ) + + left_grasp_pose, right_grasp_pose = build_object_grasp_poses( + object_pose, + object_vertices, + preset, + sim.device, + ) + target_pose = build_object_target_pose( + object_pose, + object_vertices, + preset, + sim.device, + ) + log_scene_targets( + preset.label, + object_pose, + target_pose, + left_grasp_pose, + right_grasp_pose, + ) + if not args.no_vis_eef_axis: + draw_pickment_target_axes( + sim, + target_pose, + left_grasp_pose, + right_grasp_pose, + ) + + left_object_to_eef = torch.bmm( + invert_pose(object_pose.unsqueeze(0)), + left_grasp_pose.unsqueeze(0), + ) + right_object_to_eef = torch.bmm( + invert_pose(object_pose.unsqueeze(0)), + right_grasp_pose.unsqueeze(0), + ) + pickment_target = CoordinatedPickmentTarget( + object_target_pose=target_pose, + object_semantics=object_semantics, + left_object_to_eef=left_object_to_eef, + right_object_to_eef=right_object_to_eef, + object_initial_pose=object_pose, + ) + + wait_for_user = not args.auto_play and not args.headless_play + if not args.diagnose_plan and not args.headless_play: + sim.open_window() + if wait_for_user: + input("Inspect the scene, then press Enter to plan pickment...") + + start_time = time.time() + result = pickment_action.execute( + pickment_target, + WorldState(last_qpos=robot.get_qpos().clone()), + ) + logger.log_info( + f"Plan coordinated pickment cost time: {time.time() - start_time:.2f} seconds" + ) + if not result.success: + logger.log_warning("Failed to plan coordinated pickment trajectory.") + return + traj = result.trajectory + joint_ids = list(range(robot.dof)) + log_action_plan( + robot, + 
"coordinated_pickment", + traj, + joint_ids, + pickment_action.get_segment_lengths(), + ) + + if args.diagnose_plan: + return + + if wait_for_user: + input("Press Enter to execute coordinated pickment...") + recording_started = start_auto_play_recording( + sim, + args, + video_prefix=f"coordinated_pickment_{args.object}_auto_play", + look_at=PICKMENT_RECORD_LOOK_AT, + ) + try: + execute_trajectory( + sim, + robot, + traj, + obj, + args.debug_state, + ) + finally: + stop_auto_play_recording(sim, recording_started) + if wait_for_user: + input("Press Enter to exit the simulation...") + + +def main() -> None: + """Run the coordinated pickment demo.""" + args = parse_arguments() + sim = initialize_simulation(args) + robot = create_dual_ur5_robot(sim) + run_coordinated_pickment_demo(args, sim, robot) + + +if __name__ == "__main__": + main() diff --git a/tests/sim/atomic_actions/test_actions.py b/tests/sim/atomic_actions/test_actions.py index 99413c4eb..051efeaf0 100644 --- a/tests/sim/atomic_actions/test_actions.py +++ b/tests/sim/atomic_actions/test_actions.py @@ -27,6 +27,9 @@ ) from embodichain.lab.sim.atomic_actions.core import ( ActionResult, + AtomicAction, + CoordinatedHeldObjectState, + CoordinatedPickmentTarget, GraspTarget, HeldObjectState, HeldObjectPoseTarget, @@ -37,6 +40,8 @@ WorldState, ) from embodichain.lab.sim.atomic_actions.actions import ( + CoordinatedPickment, + CoordinatedPickmentCfg, MoveEndEffector, MoveEndEffectorCfg, MoveJoints, @@ -55,6 +60,8 @@ ARM_DOF = 6 HAND_DOF = 2 TOTAL_DOF = ARM_DOF + HAND_DOF +DUAL_ARM_DOF = 12 +DUAL_TOTAL_DOF = DUAL_ARM_DOF + 2 * HAND_DOF def _make_mock_robot(): @@ -113,6 +120,57 @@ def _make_mock_motion_generator(): return mg +def _make_dual_arm_mock_robot(): + robot = Mock() + robot.device = torch.device("cpu") + robot.dof = DUAL_TOTAL_DOF + + def get_qpos(name=None): + if name == "left_arm": + return torch.zeros(NUM_ENVS, ARM_DOF) + if name == "right_arm": + return torch.zeros(NUM_ENVS, ARM_DOF) + if name == 
"dual_arm": + return torch.zeros(NUM_ENVS, DUAL_ARM_DOF) + if name in ("left_hand", "right_hand"): + return torch.zeros(NUM_ENVS, HAND_DOF) + return torch.zeros(NUM_ENVS, DUAL_TOTAL_DOF) + + robot.get_qpos = get_qpos + + def get_joint_ids(name=None): + if name == "left_arm": + return list(range(0, ARM_DOF)) + if name == "right_arm": + return list(range(ARM_DOF, DUAL_ARM_DOF)) + if name == "dual_arm": + return list(range(DUAL_ARM_DOF)) + if name == "left_hand": + return list(range(DUAL_ARM_DOF, DUAL_ARM_DOF + HAND_DOF)) + if name == "right_hand": + return list(range(DUAL_ARM_DOF + HAND_DOF, DUAL_TOTAL_DOF)) + return list(range(DUAL_TOTAL_DOF)) + + robot.get_joint_ids = get_joint_ids + + def compute_ik(pose=None, name=None, joint_seed=None, qpos_seed=None): + seed = joint_seed if joint_seed is not None else qpos_seed + if seed is None: + seed = torch.zeros(NUM_ENVS, ARM_DOF) + offset = 0.1 if name == "left_arm" else 0.2 + return torch.ones(seed.shape[0], dtype=torch.bool), seed + offset + + robot.compute_ik = compute_ik + return robot + + +def _make_dual_arm_mock_motion_generator(): + mg = Mock() + mg.robot = _make_dual_arm_mock_robot() + mg.device = torch.device("cpu") + return mg + + def _hand_open(): return torch.zeros(HAND_DOF, dtype=torch.float32) @@ -591,3 +649,59 @@ def interpolate(trajectory, interp_num, device): last_qpos[:, :ARM_DOF], ) assert result.next_state.held_object is held + + +# --------------------------------------------------------------------------- +# CoordinatedPickment +# --------------------------------------------------------------------------- + + +class TestCoordinatedPickmentAction: + def setup_method(self): + self.mg = _make_dual_arm_mock_motion_generator() + + def test_target_type_is_coordinated_pickment_target(self): + assert CoordinatedPickment.TargetType is CoordinatedPickmentTarget + assert CoordinatedPickment.__bases__ == (AtomicAction,) + + def test_execute_returns_full_dof_trajectory_and_dual_held_state(self): + cfg = 
CoordinatedPickmentCfg( + left_hand_open_qpos=_hand_open(), + left_hand_close_qpos=_hand_close(), + right_hand_open_qpos=_hand_open(), + right_hand_close_qpos=_hand_close(), + sample_interval=30, + hand_interp_steps=4, + hold_steps=2, + object_motion_keyframes=3, + ) + action = CoordinatedPickment(self.mg, cfg) + sem = ObjectSemantics( + affordance=AntipodalAffordance(), geometry={}, label="pencil" + ) + state = WorldState(last_qpos=torch.zeros(NUM_ENVS, DUAL_TOTAL_DOF)) + result = action.execute( + CoordinatedPickmentTarget( + object_target_pose=torch.eye(4), + object_semantics=sem, + left_object_to_eef=torch.eye(4), + right_object_to_eef=torch.eye(4), + object_initial_pose=torch.eye(4), + ), + state, + ) + assert result.success is True + assert result.trajectory.shape == (NUM_ENVS, 30, DUAL_TOTAL_DOF) + assert torch.allclose( + result.trajectory[:, -1, action.left_hand_joint_ids], + _hand_close().unsqueeze(0).repeat(NUM_ENVS, 1), + ) + assert torch.allclose( + result.trajectory[:, -1, action.right_hand_joint_ids], + _hand_close().unsqueeze(0).repeat(NUM_ENVS, 1), + ) + assert isinstance( + result.next_state.coordinated_held_object, + CoordinatedHeldObjectState, + ) + assert result.next_state.held_object is None diff --git a/tests/sim/atomic_actions/test_core.py b/tests/sim/atomic_actions/test_core.py index 234ed6b4f..2c6e3d98e 100644 --- a/tests/sim/atomic_actions/test_core.py +++ b/tests/sim/atomic_actions/test_core.py @@ -27,6 +27,8 @@ from embodichain.lab.sim.atomic_actions.core import ( ActionCfg, ActionResult, + CoordinatedHeldObjectState, + CoordinatedPickmentTarget, GraspTarget, HeldObjectState, HeldObjectPoseTarget, @@ -86,6 +88,17 @@ def test_held_object_target_is_frozen(self): with pytest.raises(dataclasses.FrozenInstanceError): t.object_target_pose = torch.zeros(4, 4) # type: ignore[misc] + def test_coordinated_pickment_target_holds_object_offsets(self): + sem = ObjectSemantics(affordance=Affordance(), geometry={}, label="pencil") + target = 
CoordinatedPickmentTarget( + object_target_pose=torch.eye(4), + object_semantics=sem, + left_object_to_eef=torch.eye(4), + right_object_to_eef=torch.eye(4), + ) + assert target.object_semantics is sem + assert target.left_object_to_eef.shape == (4, 4) + class TestObjectSemantics: def test_does_not_mutate_affordance_geometry(self): @@ -122,6 +135,21 @@ def test_required_fields(self): assert s.grasp_xpos.shape == (1, 4, 4) +class TestCoordinatedHeldObjectState: + def test_required_fields(self): + sem = ObjectSemantics(affordance=Affordance(), geometry={}) + s = CoordinatedHeldObjectState( + semantics=sem, + left_object_to_eef=torch.eye(4).unsqueeze(0), + right_object_to_eef=torch.eye(4).unsqueeze(0), + left_grasp_xpos=torch.eye(4).unsqueeze(0), + right_grasp_xpos=torch.eye(4).unsqueeze(0), + ) + assert s.semantics is sem + assert s.left_object_to_eef.shape == (1, 4, 4) + assert s.right_grasp_xpos.shape == (1, 4, 4) + + class TestWorldState: def test_constructs_with_last_qpos_only(self): qpos = torch.zeros(2, 6) @@ -139,6 +167,21 @@ def test_carries_held_object(self): ws = WorldState(last_qpos=torch.zeros(1, 6), held_object=held) assert ws.held_object is held + def test_carries_coordinated_held_object(self): + sem = ObjectSemantics(affordance=Affordance(), geometry={}) + held = CoordinatedHeldObjectState( + semantics=sem, + left_object_to_eef=torch.eye(4).unsqueeze(0), + right_object_to_eef=torch.eye(4).unsqueeze(0), + left_grasp_xpos=torch.eye(4).unsqueeze(0), + right_grasp_xpos=torch.eye(4).unsqueeze(0), + ) + ws = WorldState( + last_qpos=torch.zeros(1, 14), + coordinated_held_object=held, + ) + assert ws.coordinated_held_object is held + class TestActionResult: def test_shape_contract(self):