diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.atomic_actions.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.atomic_actions.rst index 181086c3..451d7db1 100644 --- a/docs/source/api_reference/embodichain/embodichain.lab.sim.atomic_actions.rst +++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.atomic_actions.rst @@ -10,12 +10,16 @@ embodichain.lab.sim.atomic_actions Affordance InteractionPoints ObjectSemantics + HeldObjectState + MoveObjectTarget ActionCfg AtomicAction MoveActionCfg MoveAction PickUpActionCfg PickUpAction + MoveObjectActionCfg + MoveObjectAction PlaceActionCfg PlaceAction AtomicActionEngine @@ -37,6 +41,14 @@ Core :members: :show-inheritance: +.. autoclass:: HeldObjectState + :members: + :show-inheritance: + +.. autoclass:: MoveObjectTarget + :members: + :show-inheritance: + .. autoclass:: ActionCfg :members: :exclude-members: __init__, copy, replace, to_dict, validate @@ -66,6 +78,15 @@ Actions :members: :show-inheritance: +.. autoclass:: MoveObjectActionCfg + :members: + :exclude-members: __init__, copy, replace, to_dict, validate + :show-inheritance: + +.. autoclass:: MoveObjectAction + :members: + :show-inheritance: + .. autoclass:: PlaceActionCfg :members: :exclude-members: __init__, copy, replace, to_dict, validate diff --git a/docs/source/overview/sim/atomic_actions.md b/docs/source/overview/sim/atomic_actions.md index 979df571..878f549c 100644 --- a/docs/source/overview/sim/atomic_actions.md +++ b/docs/source/overview/sim/atomic_actions.md @@ -42,6 +42,13 @@ AtomicActionEngine ◄─────────────── PlanResult - `affordance` — *how* to interact with the object (e.g. antipodal grasp poses) - `entity` — a live reference to the simulation object, so actions can read its current pose +**`HeldObjectState`** is runtime state produced after a successful semantic pickup. 
It stores +the held object's semantics, the object-to-end-effector transform, and the selected grasp pose so later actions can move the +object without recomputing the grasp. It is intentionally separate from `ObjectSemantics`, +which remains a reusable object description rather than per-execution robot state. + +**`MoveObjectTarget`** describes an object-centric target pose for an already-held object. + **`Affordance`** is a data class that encodes a specific interaction capability. The built-in affordance types are: | Class | Use case | @@ -67,6 +74,7 @@ The following actions are available out of the box: |---|---|---|---| | `MoveAction` | `MoveActionCfg` | `Tensor (4,4)` — EEF pose | Move arm to pose | | `PickUpAction` | `PickUpActionCfg` | `ObjectSemantics` or `Tensor (4,4)` | Approach → close gripper → lift | +| `MoveObjectAction` | `MoveObjectActionCfg` | `MoveObjectTarget` | Move held object and keep gripper closed | | `PlaceAction` | `PlaceActionCfg` | `Tensor (4,4)` — EEF release pose | Lower → open gripper → retract | ### `MoveAction` @@ -101,6 +109,26 @@ Three-phase grasp motion: *approach → close gripper → lift*. --- +### `MoveObjectAction` + +Moves a held object to an object-centric target pose while preserving the grasp. It consumes +the `HeldObjectState` produced by a prior `PickUpAction` whose target was an `ObjectSemantics`; a pickup given a raw pose tensor records no held-object state, and `MoveObjectAction` reports an error without one. + +`HeldObjectState` and `MoveObjectTarget` are intentionally kept separate from +`ObjectSemantics`: `ObjectSemantics` describes the object and affordances, while these +types describe runtime held-object state and action-specific targets. + +| Config field | Default | Description | +|---|---|---| +| `hand_close_qpos` | `None` | **Required.** Gripper closed joint positions | +| `hand_control_part` | `"hand"` | Robot control part for the gripper | +| `sample_interval` | `50` | Number of waypoints in the trajectory | + +**Target:** `MoveObjectTarget` or `dict` with `"object_target_pose"` containing a `torch.Tensor` +of shape `(4, 4)` or `(n_envs, 4, 4)`. 
+ +--- + ### `PlaceAction` Three-phase release motion: *lower → open gripper → retract*. Mirrors `PickUpAction`. @@ -119,6 +147,8 @@ from embodichain.lab.sim.atomic_actions import ( ObjectSemantics, AntipodalAffordance, PickUpActionCfg, + MoveObjectActionCfg, + MoveObjectTarget, PlaceActionCfg, MoveActionCfg, ) @@ -131,6 +161,7 @@ pickup_cfg = PickUpActionCfg( hand_close_qpos=torch.tensor([0.025, 0.025]), ) place_cfg = PlaceActionCfg(...) +move_object_cfg = MoveObjectActionCfg(hand_close_qpos=torch.tensor([0.025, 0.025])) move_cfg = MoveActionCfg(control_part="arm") # 2. Build the engine — action order matches target_list order @@ -228,8 +259,10 @@ is_success, traj = engine.execute_static(target_list=[target_pose]) |---|---| | `torch.Tensor (4,4)` or `(n_envs,4,4)` | EEF pose, broadcast across envs | | `ObjectSemantics` | Passed directly to the action | +| `MoveObjectTarget` | Passed directly to `MoveObjectAction` | | `str` (object label) | Looked up in `SemanticAnalyzer` cache | | `dict` with `"pose"` key | Unwrapped to tensor | +| `dict` with `"object_target_pose"` key | Wrapped as `MoveObjectTarget` | | `dict` with `"label"` key | Analyzed via `SemanticAnalyzer` | --- @@ -239,3 +272,4 @@ is_success, traj = engine.execute_static(target_list=[target_pose]) - {doc}`planners/motion_generator` — the trajectory planner used by every action - {doc}`sim_robot` — how control parts and IK solvers are configured - Tutorial: `scripts/tutorials/sim/atomic_actions.py` +- Move object demo: `scripts/tutorials/atomic_action/move_object_atomic_actions.py` diff --git a/docs/source/tutorial/atomic_actions.rst b/docs/source/tutorial/atomic_actions.rst index 10b8e97c..6a7b0b05 100644 --- a/docs/source/tutorial/atomic_actions.rst +++ b/docs/source/tutorial/atomic_actions.rst @@ -15,7 +15,8 @@ Key Features - **Semantic-aware execution** — actions accept either a raw pose tensor or an ``ObjectSemantics`` descriptor that bundles affordance data (grasp poses, interaction points) with the 
simulation entity. -- **Three built-in primitives** — ``MoveAction``, ``PickUpAction``, and ``PlaceAction`` +- **Built-in primitives** — ``MoveAction``, ``PickUpAction``, ``MoveObjectAction``, + and ``PlaceAction`` cover the most common tabletop manipulation workflows out of the box. See the :ref:`supported_atomic_actions` table for configs and target types. - **Extensible registry** — custom actions can be registered globally with @@ -53,6 +54,7 @@ Setting up the engine from embodichain.lab.sim.atomic_actions import ( AtomicActionEngine, PickUpActionCfg, + MoveObjectActionCfg, PlaceActionCfg, MoveActionCfg, ) @@ -78,6 +80,11 @@ Setting up the engine hand_control_part="hand", lift_height=0.15, ) + move_object_cfg = MoveObjectActionCfg( + hand_close_qpos=hand_close, + control_part="arm", + hand_control_part="hand", + ) move_cfg = MoveActionCfg(control_part="arm") engine = AtomicActionEngine( @@ -139,6 +146,30 @@ Executing a pick-place-move sequence robot.set_qpos(trajectory[:, i]) sim.update(step=4) +Moving a held object +~~~~~~~~~~~~~~~~~~~~ + +``MoveObjectAction`` consumes the runtime ``HeldObjectState`` produced by a previous +``PickUpAction`` whose target was an ``ObjectSemantics`` (a raw pose target records no held-object state). The target is object-centric, so the caller specifies where the +object should move, and the action converts that pose into an end-effector target while +keeping the gripper closed. + +.. 
code-block:: python + + from embodichain.lab.sim.atomic_actions import MoveObjectTarget + + engine = AtomicActionEngine( + motion_generator=motion_gen, + actions_cfg_list=[pickup_cfg, move_object_cfg], + ) + + object_target_pose = torch.eye(4, dtype=torch.float32, device=device) + object_target_pose[:3, 3] = torch.tensor([0.3, -0.2, 0.25], device=device) + + is_success, trajectory = engine.execute_static( + target_list=[semantics, MoveObjectTarget(object_target_pose=object_target_pose)] + ) + Registering custom actions ~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/embodichain/lab/sim/atomic_actions/__init__.py b/embodichain/lab/sim/atomic_actions/__init__.py index cf1e60ce..466419b3 100644 --- a/embodichain/lab/sim/atomic_actions/__init__.py +++ b/embodichain/lab/sim/atomic_actions/__init__.py @@ -26,14 +26,18 @@ AntipodalAffordance, InteractionPoints, ObjectSemantics, + HeldObjectState, + MoveObjectTarget, ActionCfg, AtomicAction, ) from .actions import ( MoveAction, + MoveObjectAction, PickUpAction, PlaceAction, MoveActionCfg, + MoveObjectActionCfg, PickUpActionCfg, PlaceActionCfg, ) @@ -47,16 +51,19 @@ __all__ = [ # Core classes "Affordance", - "GraspPose", "InteractionPoints", "ObjectSemantics", + "HeldObjectState", + "MoveObjectTarget", "ActionCfg", "AtomicAction", # Action implementations "MoveAction", + "MoveObjectAction", "PickUpAction", "PlaceAction", "MoveActionCfg", + "MoveObjectActionCfg", "PickUpActionCfg", "PlaceActionCfg", # Engine diff --git a/embodichain/lab/sim/atomic_actions/actions.py b/embodichain/lab/sim/atomic_actions/actions.py index 1aa8901a..f5d03d78 100644 --- a/embodichain/lab/sim/atomic_actions/actions.py +++ b/embodichain/lab/sim/atomic_actions/actions.py @@ -17,14 +17,22 @@ from __future__ import annotations import torch -from typing import Optional, Union, TYPE_CHECKING, Any +from typing import Optional, Union, TYPE_CHECKING from embodichain.lab.sim.planners import PlanResult, PlanState, MoveType from 
embodichain.lab.sim.planners.motion_generator import MotionGenOptions from embodichain.lab.sim.planners.toppra_planner import ToppraPlanOptions -from .core import AtomicAction, ObjectSemantics, AntipodalAffordance, ActionCfg +from .core import ( + AtomicAction, + ObjectSemantics, + AntipodalAffordance, + ActionCfg, + HeldObjectState, + MoveObjectTarget, +) from embodichain.utils import logger from embodichain.utils import configclass +from embodichain.utils.math import pose_inv from embodichain.lab.sim.utility.action_utils import interpolate_with_distance import numpy as np @@ -43,11 +51,8 @@ class MoveActionCfg(ActionCfg): @configclass -class GraspActionCfg(MoveActionCfg): - """Shared configuration for actions that involve gripper open/close motions.""" - - hand_open_qpos: torch.Tensor | None = None - """[hand_dof,] of float. Joint positions for open hand state.""" +class HandCloseActionCfg(MoveActionCfg): + """Shared configuration for actions that keep or move the gripper closed.""" hand_close_qpos: torch.Tensor | None = None """[hand_dof,] of float. Joint positions for closed hand state.""" @@ -55,6 +60,14 @@ class GraspActionCfg(MoveActionCfg): hand_control_part: str = "hand" """Name of the robot part that controls the hand joints.""" + +@configclass +class GraspActionCfg(HandCloseActionCfg): + """Shared configuration for actions that involve gripper open/close motions.""" + + hand_open_qpos: torch.Tensor | None = None + """[hand_dof,] of float. 
Joint positions for open hand state.""" + lift_height: float = 0.1 """Height (m) to lift the end-effector after the gripper phase.""" @@ -85,6 +98,12 @@ def __init__( self.arm_joint_ids = self.robot.get_joint_ids(name=self.cfg.control_part) self.dof = len(self.arm_joint_ids) + def _all_envs_success(self, is_success: bool | torch.Tensor) -> bool: + """Return true only when all environments report success.""" + if isinstance(is_success, torch.Tensor): + return bool(torch.all(is_success).item()) + return bool(is_success) + def _resolve_pose_target( self, target: Union[ObjectSemantics, torch.Tensor], @@ -209,7 +228,7 @@ def _plan_arm_trajectory( is_success, qpos = self.robot.compute_ik( pose=xpos_traj[:, j], name=self.cfg.control_part, joint_seed=qpos_seed ) - if not is_success: + if not self._all_envs_success(is_success): logger.log_warning( f"Failed to compute IK for target state {j} in some environments. " "The resulting trajectory may be invalid." @@ -231,11 +250,30 @@ def _interpolate_hand_qpos( n_waypoints: int, ) -> torch.Tensor: """Interpolate hand joint positions between two gripper states.""" - weights = torch.linspace(0, 1, steps=n_waypoints, device=self.device) - hand_qpos_list = [ - torch.lerp(start_hand_qpos, end_hand_qpos, weight) for weight in weights - ] - return torch.stack(hand_qpos_list, dim=0) + is_unbatched = start_hand_qpos.dim() == 1 and end_hand_qpos.dim() == 1 + start_hand_qpos = start_hand_qpos.to(self.device) + end_hand_qpos = end_hand_qpos.to(self.device) + + if start_hand_qpos.dim() == 1: + start_hand_qpos = start_hand_qpos.unsqueeze(0) + if end_hand_qpos.dim() == 1: + end_hand_qpos = end_hand_qpos.unsqueeze(0) + + weights = torch.linspace( + 0, + 1, + steps=n_waypoints, + device=self.device, + dtype=start_hand_qpos.dtype, + ) + result = torch.lerp( + start_hand_qpos.unsqueeze(1), + end_hand_qpos.unsqueeze(1), + weights[None, :, None], + ) + if is_unbatched: + return result.squeeze(0) + return result def execute( self, @@ -262,9 +300,7 @@ 
def execute( # TODO: warning and fallback if no valid grasp pose found if not is_success: - logger.log_warning( - "Failed to resolve grasp pose, using default approach pose" - ) + logger.log_warning("Failed to resolve move target pose.") return False, torch.empty(0), self.arm_joint_ids target_states_list = [ @@ -283,6 +319,48 @@ def validate(self, target, start_qpos=None, **kwargs): return True +class _HandCloseAction(MoveAction): + """Internal base for actions that keep the gripper closed.""" + + def __init__( + self, + motion_generator: MotionGenerator, + cfg: HandCloseActionCfg, + *, + cfg_name: str, + ): + super().__init__(motion_generator, cfg=cfg) + self._held_object_state: HeldObjectState | None = None + if self.cfg.hand_close_qpos is None: + logger.log_error(f"hand_close_qpos must be specified in {cfg_name}") + self.hand_close_qpos = self.cfg.hand_close_qpos.to(self.device) + + self.hand_joint_ids = self.robot.get_joint_ids(name=self.cfg.hand_control_part) + self.joint_ids = self.arm_joint_ids + self.hand_joint_ids + self.arm_dof = len(self.arm_joint_ids) + self.dof = len(self.joint_ids) + + def _expand_hand_qpos(self, hand_qpos: torch.Tensor) -> torch.Tensor: + """Resolve hand qpos to batched shape ``(n_envs, hand_dof)``.""" + hand_dof = len(self.hand_joint_ids) + hand_qpos = hand_qpos.to(device=self.device, dtype=torch.float32) + if hand_qpos.shape == (hand_dof,): + return hand_qpos.unsqueeze(0).repeat(self.n_envs, 1) + if hand_qpos.shape == (self.n_envs, hand_dof): + return hand_qpos + logger.log_error( + f"hand_qpos must have shape ({hand_dof},) or " + f"({self.n_envs}, {hand_dof}), but got {hand_qpos.shape}", + ValueError, + ) + + def _repeat_hand_qpos( + self, hand_qpos: torch.Tensor, n_waypoints: int + ) -> torch.Tensor: + """Repeat hand qpos across trajectory waypoints.""" + return self._expand_hand_qpos(hand_qpos).unsqueeze(1).repeat(1, n_waypoints, 1) + + @configclass class PickUpActionCfg(GraspActionCfg): name: str = "pick_up" @@ -298,6 +376,8 @@ 
class PickUpActionCfg(GraspActionCfg): class PickUpAction(MoveAction): + updates_held_object_state = True + def __init__( self, motion_generator: MotionGenerator, @@ -312,7 +392,7 @@ def __init__( super().__init__( motion_generator, cfg=cfg if cfg is not None else PickUpActionCfg() ) - self.cfg = cfg + self.cfg = cfg if cfg is not None else self.cfg self.approach_direction = self.cfg.approach_direction.to(self.device) if self.cfg.hand_open_qpos is None: logger.log_error("hand_open_qpos must be specified in PickUpActionCfg") @@ -325,6 +405,7 @@ def __init__( self.joint_ids = self.arm_joint_ids + self.hand_joint_ids self.arm_dof = len(self.arm_joint_ids) self.dof = len(self.joint_ids) + self._held_object_state: HeldObjectState | None = None def execute( self, @@ -346,20 +427,31 @@ def execute( """ # Resolve grasp pose - if isinstance(target, ObjectSemantics): - is_success, grasp_xpos = self._resolve_grasp_pose(target) + self._held_object_state = None + target_semantics = target if isinstance(target, ObjectSemantics) else None + if target_semantics is not None: + is_success, grasp_xpos = self._resolve_grasp_pose(target_semantics) else: is_success, grasp_xpos = self._resolve_pose_target( target, action_name=self.__class__.__name__ ) # TODO: warning and fallback if no valid grasp pose found - if not is_success: + if not self._all_envs_success(is_success): logger.log_warning( "Failed to resolve grasp pose, using default approach pose" ) return False, torch.empty(0), self.joint_ids + if target_semantics is not None: + obj_poses = target_semantics.entity.get_local_pose(to_matrix=True) + object_to_eef = torch.bmm(pose_inv(obj_poses), grasp_xpos) + self._held_object_state = HeldObjectState( + semantics=target_semantics, + object_to_eef=object_to_eef, + grasp_xpos=grasp_xpos, + ) + # Compute pre-grasp pose # TODO: only for parallel gripper, approach in negative grasp z direction grasp_z = grasp_xpos[:, :3, 2] @@ -459,7 +551,7 @@ def execute( def _resolve_grasp_pose( self, 
semantics: ObjectSemantics - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: if not isinstance(semantics.affordance, AntipodalAffordance): logger.log_error( "Grasp pose affordance must be of type AntipodalAffordance" @@ -516,6 +608,142 @@ def validate(self, target, start_qpos=None, **kwargs): # TODO: implement proper validation logic for pick up action return True + def get_held_object_state(self) -> HeldObjectState | None: + """Return the held-object state produced by the latest successful pickup.""" + return self._held_object_state + + +@configclass +class MoveObjectActionCfg(HandCloseActionCfg): + name: str = "move_object" + """Name of the action, used for identification and logging.""" + + +class MoveObjectAction(_HandCloseAction): + updates_held_object_state = True + + def __init__( + self, + motion_generator: MotionGenerator, + cfg: MoveObjectActionCfg | None = None, + ): + """ + Initialize the atomic action. + Args: + motion_generator: The motion generator instance to use for planning. + cfg: Configuration for the action. 
+ """ + super().__init__( + motion_generator, + cfg=cfg if cfg is not None else MoveObjectActionCfg(), + cfg_name="MoveObjectActionCfg", + ) + + def _resolve_move_object_target( + self, + target: MoveObjectTarget, + action_context: dict | None = None, + held_object_state: HeldObjectState | None = None, + ) -> tuple[bool, torch.Tensor, HeldObjectState]: + """Resolve an object target pose into an end-effector target pose.""" + if not isinstance(target, MoveObjectTarget): + logger.log_error( + "MoveObjectAction target must be a MoveObjectTarget.", + TypeError, + ) + + held_state = held_object_state + if held_state is None and action_context is not None: + held_state = action_context.get("held_object_state") + if held_state is None: + logger.log_error( + "MoveObjectTarget requires a HeldObjectState from a prior PickUpAction.", + ValueError, + ) + + object_target_pose = target.object_target_pose.to( + device=self.device, dtype=torch.float32 + ) + if object_target_pose.shape == (4, 4): + object_target_pose = object_target_pose.unsqueeze(0).repeat( + self.n_envs, 1, 1 + ) + if object_target_pose.shape != (self.n_envs, 4, 4): + logger.log_error( + f"object_target_pose must have shape (4, 4) or " + f"({self.n_envs}, 4, 4), but got {object_target_pose.shape}", + ValueError, + ) + + object_to_eef = held_state.object_to_eef.to( + device=self.device, dtype=torch.float32 + ) + if object_to_eef.shape == (4, 4): + object_to_eef = object_to_eef.unsqueeze(0).repeat(self.n_envs, 1, 1) + if object_to_eef.shape != (self.n_envs, 4, 4): + logger.log_error( + f"object_to_eef must have shape (4, 4) or " + f"({self.n_envs}, 4, 4), but got {object_to_eef.shape}", + ValueError, + ) + + move_object_xpos = torch.bmm(object_target_pose, object_to_eef) + return True, move_object_xpos, held_state + + def execute( + self, + target: MoveObjectTarget, + start_qpos: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[bool, torch.Tensor, list[float]]: + """Move the held object to a target object 
pose and keep grasping it.""" + is_success, move_object_xpos, held_state = self._resolve_move_object_target( + target, + action_context=kwargs.get("action_context"), + held_object_state=kwargs.get("held_object_state"), + ) + start_qpos = self._resolve_start_qpos(start_qpos, self.arm_dof) + self._held_object_state = held_state + + if not is_success: + logger.log_warning("Failed to resolve move_object target pose.") + return False, torch.empty(0), self.joint_ids + + target_states_list = [ + [ + PlanState(xpos=move_object_xpos[i], move_type=MoveType.EEF_MOVE), + ] + for i in range(self.n_envs) + ] + trajectory = torch.zeros( + size=(self.n_envs, self.cfg.sample_interval, self.dof), + dtype=torch.float32, + device=self.device, + ) + is_success, plan_traj = self._plan_arm_trajectory( + target_states_list, + start_qpos, + self.cfg.sample_interval, + self.arm_dof, + ) + if not is_success: + logger.log_warning("Failed to plan move_object trajectory.") + return False, trajectory, self.joint_ids + trajectory[:, :, : self.arm_dof] = plan_traj + trajectory[:, :, self.arm_dof :] = self._repeat_hand_qpos( + self.hand_close_qpos, + self.cfg.sample_interval + ) + return True, trajectory, self.joint_ids + + def get_held_object_state(self) -> HeldObjectState | None: + """Return the held-object state after moving the object.""" + return self._held_object_state + + def validate(self, target, start_qpos=None, **kwargs): + # TODO: implement proper validation logic for move object action + return True + @configclass class PlaceActionCfg(GraspActionCfg): @@ -524,6 +752,8 @@ class PlaceActionCfg(GraspActionCfg): class PlaceAction(MoveAction): + updates_held_object_state = True + def __init__( self, motion_generator: MotionGenerator, @@ -538,7 +768,7 @@ def __init__( super().__init__( motion_generator, cfg=cfg if cfg is not None else PlaceActionCfg() ) - self.cfg = cfg + self.cfg = cfg if cfg is not None else self.cfg if self.cfg.hand_open_qpos is None: logger.log_error("hand_open_qpos must 
be specified in PlaceActionCfg") if self.cfg.hand_close_qpos is None: @@ -669,3 +899,7 @@ def execute( def validate(self, target, start_qpos=None, **kwargs): # TODO: implement proper validation logic for pick up action return True + + def get_held_object_state(self) -> HeldObjectState | None: + """Return None after place releases the held object.""" + return None diff --git a/embodichain/lab/sim/atomic_actions/core.py b/embodichain/lab/sim/atomic_actions/core.py index a30698cb..cac89de8 100644 --- a/embodichain/lab/sim/atomic_actions/core.py +++ b/embodichain/lab/sim/atomic_actions/core.py @@ -19,7 +19,7 @@ import torch from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING +from typing import Any, ClassVar, Dict, List, Optional, Union, TYPE_CHECKING from embodichain.lab.sim.planners import PlanResult, PlanState, MoveType from embodichain.utils import configclass @@ -320,6 +320,28 @@ def __post_init__(self) -> None: self.affordance.geometry = self.geometry +@dataclass +class HeldObjectState: + """State shared by actions while an object is held by the robot.""" + + semantics: ObjectSemantics + """Semantic object currently held by the gripper.""" + + object_to_eef: torch.Tensor + """Batched transform from object frame to end-effector frame, shape [B, 4, 4].""" + + grasp_xpos: torch.Tensor + """Batched end-effector grasp pose selected during pickup, shape [B, 4, 4].""" + + +@dataclass +class MoveObjectTarget: + """Object-centric target for moving a held object without releasing it.""" + + object_target_pose: torch.Tensor + """Target object pose, shape [4, 4] or [B, 4, 4].""" + + # ============================================================================= # ActionCfg and AtomicAction # ============================================================================= @@ -353,6 +375,9 @@ class AtomicAction(ABC): the existing motion planning infrastructure. 
""" + updates_held_object_state: ClassVar[bool] = False + """Whether the engine should read held-object state after this action.""" + def __init__( self, motion_generator: MotionGenerator, @@ -370,6 +395,10 @@ def __init__( self.control_part = cfg.control_part self.device = self.robot.device + def get_held_object_state(self) -> HeldObjectState | None: + """Return held-object state after execution if this action updates it.""" + return None + @abstractmethod def execute( self, diff --git a/embodichain/lab/sim/atomic_actions/engine.py b/embodichain/lab/sim/atomic_actions/engine.py index 15b868a8..591abe41 100644 --- a/embodichain/lab/sim/atomic_actions/engine.py +++ b/embodichain/lab/sim/atomic_actions/engine.py @@ -21,7 +21,7 @@ from embodichain.lab.sim.planners import PlanResult from embodichain.utils import logger -from .core import AtomicAction, ObjectSemantics, ActionCfg +from .core import AtomicAction, ObjectSemantics, ActionCfg, MoveObjectTarget if TYPE_CHECKING: from embodichain.lab.sim.planners import MotionGenerator @@ -178,6 +178,7 @@ def __init__( # Semantic analyzer for object understanding self._semantic_analyzer = SemanticAnalyzer() + self._action_context: Dict[str, Any] = {} # Initialize default actions self._actions: Dict[str, AtomicAction] = self._init_actions(actions_cfg_list) @@ -186,11 +187,12 @@ def _init_actions( self, actions_cfg_list: Optional[List[ActionCfg]] = None ) -> Dict[str, "AtomicAction"]: actions: Dict[str, AtomicAction] = {} - from .actions import MoveAction, PickUpAction, PlaceAction + from .actions import MoveAction, MoveObjectAction, PickUpAction, PlaceAction builtin_action_map: Dict[str, Type[AtomicAction]] = { "move": MoveAction, "pick_up": PickUpAction, + "move_object": MoveObjectAction, "place": PlaceAction, } if actions_cfg_list is not None: @@ -207,7 +209,15 @@ def _init_actions( def execute_static( self, - target_list: List[Union[torch.Tensor, str, ObjectSemantics, Dict[str, Any]]], + target_list: List[ + Union[ + 
torch.Tensor, + str, + ObjectSemantics, + MoveObjectTarget, + Dict[str, Any], + ] + ], ) -> tuple[bool, torch.Tensor]: """Execute a sequence of actions to target poses. @@ -219,6 +229,7 @@ def execute_static( logger.log_error( f"Length of target_list ({len(target_list)}) must match number of actions ({len(action_names)})." ) + self._action_context.clear() start_qpos = self.motion_generator.robot.get_qpos() n_envs = start_qpos.shape[0] all_dof = self.motion_generator.robot.dof @@ -233,10 +244,21 @@ def execute_static( arm_joint_ids = self.motion_generator.robot.get_joint_ids(name=control_part) start_qpos_part = start_qpos[:, arm_joint_ids] is_success, traj, joint_ids = atom_action.execute( - target=target, start_qpos=start_qpos_part + target=target, + start_qpos=start_qpos_part, + action_context=self._action_context, + held_object_state=self._action_context.get("held_object_state"), ) if not is_success: return False, all_trajectory + + if atom_action.updates_held_object_state: + held_state = atom_action.get_held_object_state() + if held_state is None: + self._action_context.pop("held_object_state", None) + else: + self._action_context["held_object_state"] = held_state + n_waypoints = traj.shape[1] traj_full = torch.zeros( @@ -254,7 +276,9 @@ def execute_static( def validate( self, action_name: str, - target: Union[torch.Tensor, str, ObjectSemantics, Dict[str, Any]], + target: Union[ + torch.Tensor, str, ObjectSemantics, MoveObjectTarget, Dict[str, Any] + ], **kwargs, ) -> bool: """Validate if a named action is feasible without executing.""" @@ -268,8 +292,10 @@ def validate( def _resolve_target( self, - target: Union[torch.Tensor, str, ObjectSemantics, Dict[str, Any]], - ) -> Union[torch.Tensor, ObjectSemantics]: + target: Union[ + torch.Tensor, str, ObjectSemantics, MoveObjectTarget, Dict[str, Any] + ], + ) -> Union[torch.Tensor, ObjectSemantics, MoveObjectTarget]: """Resolve user target input into tensor pose or ObjectSemantics. 
Supports the convenience dict format in ``execute`` and ``validate``. @@ -277,13 +303,21 @@ def _resolve_target( if isinstance(target, torch.Tensor): return target - if isinstance(target, ObjectSemantics): + if isinstance(target, (ObjectSemantics, MoveObjectTarget)): return target if isinstance(target, str): return self._semantic_analyzer.analyze(target) if isinstance(target, dict): + if "object_target_pose" in target: + object_target_pose = target["object_target_pose"] + if not isinstance(object_target_pose, torch.Tensor): + raise TypeError( + "target['object_target_pose'] must be a torch.Tensor" + ) + return MoveObjectTarget(object_target_pose=object_target_pose) + if "pose" in target: pose = target["pose"] if not isinstance(pose, torch.Tensor): @@ -328,7 +362,8 @@ def _resolve_target( return semantics raise TypeError( - "target must be torch.Tensor, str, ObjectSemantics, or Dict[str, Any]" + "target must be torch.Tensor, str, ObjectSemantics, MoveObjectTarget, " + "or Dict[str, Any]" ) def get_semantic_analyzer(self) -> SemanticAnalyzer: diff --git a/embodichain/lab/sim/solvers/pytorch_solver.py b/embodichain/lab/sim/solvers/pytorch_solver.py index c0fcf465..277eaf87 100644 --- a/embodichain/lab/sim/solvers/pytorch_solver.py +++ b/embodichain/lab/sim/solvers/pytorch_solver.py @@ -62,7 +62,7 @@ class PytorchSolverCfg(SolverCfg): is_only_position_constraint: bool = False """Flag to indicate whether the solver should only consider position constraints.""" - num_samples: int = 5 + num_samples: int = 30 """Number of samples to generate different joint seeds for IK iterations. 
A higher number of samples increases the chances of finding a valid solution diff --git a/scripts/tutorials/atomic_action/move_object_atomic_actions.py b/scripts/tutorials/atomic_action/move_object_atomic_actions.py new file mode 100644 index 00000000..abce7271 --- /dev/null +++ b/scripts/tutorials/atomic_action/move_object_atomic_actions.py @@ -0,0 +1,434 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +"""Demonstrate moving a held object to an object-centric target pose.""" + +from __future__ import annotations + +import argparse +import sys +import time +from pathlib import Path + +_REPO_ROOT = Path(__file__).resolve().parents[3] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +import torch + +from embodichain.data import get_data_path +from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.atomic_actions import ( + AntipodalAffordance, + AtomicActionEngine, + MoveActionCfg, + MoveObjectActionCfg, + MoveObjectTarget, + ObjectSemantics, + PickUpActionCfg, +) +from embodichain.lab.sim.cfg import ( + JointDrivePropertiesCfg, + LightCfg, + RenderCfg, + RigidBodyAttributesCfg, + RigidObjectCfg, + RobotCfg, + URDFCfg, +) +from embodichain.lab.sim.objects import RigidObject, Robot +from embodichain.lab.sim.planners import MotionGenerator, MotionGenCfg, ToppraPlannerCfg +from embodichain.lab.sim.shapes import CubeCfg, MeshCfg +from embodichain.lab.sim.solvers import PytorchSolverCfg +from embodichain.toolkits.graspkit.pg_grasp.antipodal_generator import ( + AntipodalSamplerCfg, + GraspGeneratorCfg, +) +from embodichain.toolkits.graspkit.pg_grasp.gripper_collision_checker import ( + GripperCollisionCfg, +) +from embodichain.utils import logger + +GRIPPER_URDF_PATH = "DH_PGI_140_80/DH_PGI_140_80.urdf" +GRIPPER_HAND_JOINT_PATTERN = "GRIPPER_FINGER1_JOINT_1" +GRIPPER_MAX_OPEN_WIDTH = 0.080 +GRIPPER_FINGER_LENGTH = 0.088 +GRIPPER_ROOT_Z_WIDTH = 0.096 +GRIPPER_Y_THICKNESS = 0.040 +GRIPPER_TCP_Z = 0.15 + +BOTTLE_LABEL = "bottle" +BOTTLE_APPROACH_DIRECTION = (0.0, 0.0, -1.0) +BOTTLE_MIN_HAND_CLOSE_QPOS = 0.024 + +MOVE_SAMPLE_INTERVAL = 60 +PICK_SAMPLE_INTERVAL = 120 +MOVE_OBJECT_SAMPLE_INTERVAL = 120 +HAND_INTERP_STEPS = 12 +POST_TRAJECTORY_STEPS 
def initialize_simulation(args: argparse.Namespace) -> SimulationManager:
    """Create the simulation manager used by the demo and add its one light.

    Args:
        args: Parsed CLI namespace; ``args.device`` and ``args.renderer``
            are the only attributes read here.

    Returns:
        The ``SimulationManager`` built from the configuration below, with a
        light registered under the uid ``"main_light"``.
    """
    sim_device = args.device
    render_cfg = RenderCfg(renderer=args.renderer)
    manager_cfg = SimulationManagerCfg(
        headless=True,
        sim_device=sim_device,
        render_cfg=render_cfg,
        physics_dt=1.0 / 100.0,
        arena_space=2.5,
    )
    sim = SimulationManager(manager_cfg)

    main_light_cfg = LightCfg(
        uid="main_light",
        color=(0.6, 0.6, 0.6),
        intensity=30.0,
        init_pos=(1.0, 0.0, 3.0),
    )
    sim.add_light(cfg=main_light_cfg)
    return sim
def settle_object(sim: SimulationManager, obj: RigidObject, step: int = 5) -> None:
    """Reset ``obj``, advance the simulation, then clear the object's dynamics.

    Args:
        sim: Simulation manager; ``sim.device.type`` selects the GPU branch.
        obj: Object to reset and settle.
        step: Value forwarded to ``sim.update(step=...)``.
    """
    running_on_gpu = sim.device.type == "cuda"
    if running_on_gpu:
        # Only the CUDA path calls init_gpu_physics(), and it does so before
        # the object is reset.
        sim.init_gpu_physics()

    obj.reset()
    sim.update(step=step)
    obj.clear_dynamics()
def compute_pick_close_end_step() -> int:
    """Return the replay step count after which the gripper has closed.

    The value is the sum of three waypoint counts taken from the module
    constants: the whole move action (``MOVE_SAMPLE_INTERVAL``), the approach
    share of the pick-up motion, and the hand interpolation
    (``HAND_INTERP_STEPS``).

    Returns:
        Number of trajectory steps up to and including the gripper close.
    """
    motion_waypoints = PICK_SAMPLE_INTERVAL - HAND_INTERP_STEPS
    # The approach share is 60 % of the motion waypoints, truncated toward
    # zero. ``motion_waypoints`` is already an int, so wrapping it in round()
    # before the multiplication had no effect and has been dropped.
    # NOTE(review): presumably this has to mirror how PickUpAction splits its
    # waypoints. If the action rounds instead of truncating, this should be
    # int(round(motion_waypoints * 0.6)) -- confirm against the action.
    n_approach = int(motion_waypoints * 0.6)
    return MOVE_SAMPLE_INTERVAL + n_approach + HAND_INTERP_STEPS
def log_object_state(obj: RigidObject, label: str) -> None:
    """Log the position and z-axis column of the object's first pose.

    Args:
        obj: Object whose ``get_local_pose(to_matrix=True)`` result is read.
        label: Prefix placed in front of the logged values.
    """
    # Only index 0 of the leading dimension is reported.
    # Presumably that dimension is the environment index -- TODO confirm.
    first_pose = obj.get_local_pose(to_matrix=True)[0]
    position_text = format_tensor(first_pose[:3, 3])
    z_axis_text = format_tensor(first_pose[:3, 2])
    logger.log_info(f"{label}: pos={position_text}, z_axis={z_axis_text}")
def main() -> None:
    """Script entry point: parse the CLI arguments and run the demo."""
    run_move_object_demo(parse_arguments())
00000000..64bcea61 --- /dev/null +++ b/scripts/tutorials/atomic_action/pickup_atomic_actions.py @@ -0,0 +1,499 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Demonstrate PickUpAction on an upright object with configurable approach.""" + +from __future__ import annotations + +import argparse +import sys +import time +from pathlib import Path + +_REPO_ROOT = Path(__file__).resolve().parents[3] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +import torch + +from embodichain.data import get_data_path +from embodichain.lab.gym.utils.gym_utils import add_env_launcher_args_to_parser +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.atomic_actions import ( + AntipodalAffordance, + AtomicActionEngine, + MoveActionCfg, + ObjectSemantics, + PickUpActionCfg, +) +from embodichain.lab.sim.cfg import ( + JointDrivePropertiesCfg, + LightCfg, + RenderCfg, + RigidBodyAttributesCfg, + RigidObjectCfg, + RobotCfg, + URDFCfg, +) +from embodichain.lab.sim.objects import RigidObject, Robot +from embodichain.lab.sim.planners import MotionGenerator, MotionGenCfg, ToppraPlannerCfg +from embodichain.lab.sim.shapes import CubeCfg, MeshCfg +from embodichain.lab.sim.solvers import PytorchSolverCfg 
def parse_arguments() -> argparse.Namespace:
    """Build the pickup-demo command line and parse the process arguments.

    The launcher options are registered first through
    ``add_env_launcher_args_to_parser``; the demo-specific flags follow in the
    order of the table below.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    cli = argparse.ArgumentParser(
        description="Demonstrate PickUpAction on an upright object."
    )
    add_env_launcher_args_to_parser(cli)

    # Each entry is (option strings, keyword arguments for add_argument).
    demo_flags = (
        (
            ("--object",),
            {
                "choices": sorted(OBJECT_PRESETS),
                "default": "paper_cup",
                "help": "Object preset to pick.",
            },
        ),
        (
            ("--n_sample",),
            {
                "type": int,
                "default": 10000,
                "help": "Number of samples for antipodal grasp generation.",
            },
        ),
        (
            ("--force_reannotate",),
            {
                "action": "store_true",
                "help": "Force grasp region re-annotation instead of using cached data.",
            },
        ),
        (
            ("--auto_play",),
            {
                "action": "store_true",
                "help": "Run the viewer demo without waiting for keyboard input.",
            },
        ),
        (
            # "--debug" is an alias; both store into ``debug_state``.
            ("--debug_state", "--debug"),
            {
                "action": "store_true",
                "help": "Log object pose during replay.",
            },
        ),
        (
            ("--approach",),
            {
                "choices": ["top", "side", "side_y", "custom"],
                "default": "side",
                "help": "Pick approach direction preset.",
            },
        ),
        (
            ("--custom_approach_direction",),
            {
                "type": float,
                "nargs": 3,
                "default": None,
                "metavar": ("X", "Y", "Z"),
                "help": "World-frame approach direction used when --approach custom.",
            },
        ),
    )
    for option_strings, options in demo_flags:
        cli.add_argument(*option_strings, **options)

    return cli.parse_args()
def create_pick_object(sim: SimulationManager, object_name: str) -> RigidObject:
    """Spawn the selected preset object and re-seat it on the table top.

    Args:
        sim: Simulation manager the object is added to.
        object_name: Key into ``OBJECT_PRESETS``.

    Returns:
        The added rigid object, reset after its ``cfg.init_pos`` was replaced
        with the position computed by ``_compute_tabletop_init_pos``.

    Raises:
        KeyError: If ``object_name`` is not a key of ``OBJECT_PRESETS``.
    """
    preset = OBJECT_PRESETS[object_name]
    label = preset["label"]
    mesh_shape = MeshCfg(fpath=get_data_path(preset["mesh_path"]))
    body_attrs = RigidBodyAttributesCfg(
        mass=preset["mass"],
        dynamic_friction=0.97,
        static_friction=0.99,
    )
    object_cfg = RigidObjectCfg(
        uid=label,
        shape=mesh_shape,
        attrs=body_attrs,
        max_convex_hull_num=16,
        # z = 0.0 is only a provisional height; it is overwritten below once
        # the object exists and its vertices can be queried.
        init_pos=[OBJECT_XY[0], OBJECT_XY[1], 0.0],
        init_rot=preset["init_rot"],
        body_scale=preset["body_scale"],
    )

    pick_obj = sim.add_rigid_object(cfg=object_cfg)
    tabletop_pos = _compute_tabletop_init_pos(pick_obj, object_cfg.init_rot)
    pick_obj.cfg.init_pos = tabletop_pos
    # Reset so the updated init_pos is applied to the spawned object.
    pick_obj.reset()
    return pick_obj
def get_hand_open_close_qpos(
    robot: Robot, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """Derive the gripper's open and close joint targets from its limits.

    Args:
        robot: Robot queried through ``get_qpos_limits(name="hand")``.
        device: Device the returned float32 tensors are placed on.

    Returns:
        ``(hand_open, hand_close)``. ``hand_open`` is column 0 of the first
        limits entry. ``hand_close`` is column 1, capped element-wise at
        ``OBJECT_MIN_HAND_CLOSE_QPOS``.
    """
    # Only the first entry of the limits is used.
    # Presumably the leading dimension is the environment index -- TODO confirm.
    limits = robot.get_qpos_limits(name="hand")[0]
    limits = limits.to(device=device, dtype=torch.float32)

    open_qpos = limits[:, 0]
    close_limit = limits[:, 1]
    # Despite the constant's name it acts as an upper bound: the close target
    # never exceeds it, nor the joint limit itself.
    close_qpos = torch.clamp(close_limit, max=OBJECT_MIN_HAND_CLOSE_QPOS)
    return open_qpos, close_qpos
def compute_pick_close_end_step() -> int:
    """Return the replay step count after which the gripper has closed.

    Adds up the waypoints of the move action (``MOVE_SAMPLE_INTERVAL``), the
    approach share of the pick-up motion, and the hand interpolation
    (``HAND_INTERP_STEPS``), all read from the module constants.

    Returns:
        Number of trajectory steps up to and including the gripper close.
    """
    motion_waypoints = PICK_SAMPLE_INTERVAL - HAND_INTERP_STEPS
    # 60 % of the motion waypoints, truncated toward zero. Calling round() on
    # the integer ``motion_waypoints`` before multiplying was a no-op, so it
    # has been removed; the computed value is unchanged.
    # NOTE(review): presumably this must match PickUpAction's own split of
    # approach vs. lift waypoints. If that split rounds, use
    # int(round(motion_waypoints * 0.6)) here instead -- confirm.
    n_approach = int(motion_waypoints * 0.6)
    return MOVE_SAMPLE_INTERVAL + n_approach + HAND_INTERP_STEPS
def run_pickup_demo(args: argparse.Namespace) -> None:
    """Plan and replay the ``move -> pick_up`` sequence for the chosen object.

    The function builds the scene, plans one static trajectory with the
    atomic action engine, replays it step by step, and then holds the final
    joint configuration for ``POST_TRAJECTORY_STEPS`` iterations. Unless
    ``args.auto_play`` is set it pauses for Enter before planning, before the
    replay and before returning.

    Args:
        args: Parsed CLI namespace. Reads ``object``, ``auto_play`` and
            ``debug_state`` directly and forwards the namespace to the
            scene-building helpers.
    """
    sim = initialize_simulation(args)
    robot = create_robot(sim)
    create_table(sim)
    obj = create_pick_object(sim, args.object)

    settle_object(sim, obj, step=5)
    semantics = create_object_semantics(obj, args)
    motion_gen = MotionGenerator(
        cfg=MotionGenCfg(planner_cfg=ToppraPlannerCfg(robot_uid=robot.uid))
    )
    hand_open, hand_close = get_hand_open_close_qpos(robot, sim.device)
    approach_direction = resolve_approach_direction(args, sim.device)
    action_cfgs = build_action_sequence(hand_open, hand_close, approach_direction)
    atomic_engine = AtomicActionEngine(
        motion_generator=motion_gen,
        actions_cfg_list=action_cfgs,
    )

    sim.open_window()
    if not args.auto_play:
        input(f"Inspect the upright {args.object}, then press Enter to plan...")

    # The move target sits at the object's current x/y with z overridden to
    # 0.36, using the fixed top-down end-effector orientation.
    obj_pose = obj.get_local_pose(to_matrix=True)
    move_position = obj_pose[0, :3, 3].clone()
    move_position[2] = 0.36
    move_target = make_top_down_eef_pose(move_position)

    logger.log_info(
        f"Planning move -> pick_up for {args.object} with "
        f"approach_direction={format_tensor(approach_direction)}"
    )
    # perf_counter() is monotonic, so the reported planning duration cannot be
    # skewed by system clock adjustments.
    start_time = time.perf_counter()
    is_success, traj = atomic_engine.execute_static(
        target_list=[move_target, semantics]
    )
    cost_time = time.perf_counter() - start_time
    logger.log_info(f"Plan trajectory cost time: {cost_time:.2f} seconds")
    if not is_success:
        logger.log_warning("Failed to plan pickup demo trajectory.")
        return

    if not args.auto_play:
        input("Press Enter to replay the pickup demo...")

    # Replay: ``traj`` is indexed as [:, step, :]; one slice is applied per
    # iteration. Presumably the axes are (n_envs, n_steps, dof) -- TODO confirm.
    post_grasp_clear_step = compute_pick_close_end_step()
    should_clear_object_dynamics = True
    log_stride = max(1, traj.shape[1] // 10)
    for i in range(traj.shape[1]):
        robot.set_qpos(traj[:, i, :])
        sim.update(step=4)
        # Clear the object's dynamics exactly once, on the first step whose
        # 1-based count reaches the computed gripper-close step.
        if should_clear_object_dynamics and i + 1 >= post_grasp_clear_step:
            obj.clear_dynamics()
            should_clear_object_dynamics = False
            logger.log_info(f"Object dynamics cleared after grasp at step={i}")
        # With --debug_state, log roughly ten evenly spaced steps plus the last.
        if args.debug_state and (i % log_stride == 0 or i == traj.shape[1] - 1):
            log_object_state(obj, f"replay step {i}/{traj.shape[1] - 1}")
        time.sleep(1e-2)

    logger.log_info(
        f"PickUpAction keeps the upright {args.object} suspended in the gripper."
    )

    # Hold the last waypoint so the final grasp can be observed.
    final_qpos = traj[:, -1, :]
    for i in range(POST_TRAJECTORY_STEPS):
        robot.set_qpos(final_qpos)
        sim.update(step=2)
        if args.debug_state and i % max(1, POST_TRAJECTORY_STEPS // 5) == 0:
            log_object_state(obj, f"post step {i}/{POST_TRAJECTORY_STEPS - 1}")
        time.sleep(1e-2)

    if not args.auto_play:
        input("Press Enter to exit the simulation...")
def test_plan_arm_trajectory_fails_on_partial_env_success_tensor(self):
    """Planning fails when the IK success tensor is not all-True.

    ``compute_ik`` is stubbed to report success for one environment and
    failure for another; ``_plan_arm_trajectory`` must then return
    ``False`` with a trajectory of shape ``(NUM_ENVS, 1, ARM_DOF)``.
    """
    target_state = Mock()
    target_state.xpos = torch.eye(4)
    target_states_list = [[target_state] for _ in range(NUM_ENVS)]
    start_qpos = torch.zeros(NUM_ENVS, ARM_DOF)
    # NOTE(review): the success flags are hard-coded to two entries, which
    # assumes NUM_ENVS == 2 -- confirm against the module constant.
    self.robot.compute_ik = Mock(
        return_value=(
            torch.tensor([True, False], dtype=torch.bool),
            torch.zeros(NUM_ENVS, ARM_DOF),
        )
    )

    is_success, trajectory = self.action._plan_arm_trajectory(
        target_states_list,
        start_qpos,
        n_waypoints=5,
    )

    assert is_success is False
    assert trajectory.shape == (NUM_ENVS, 1, ARM_DOF)
def test_execute_accepts_all_env_grasp_success_tensor(self):
    """An all-True grasp success tensor lets ``execute`` succeed.

    Grasp resolution and arm planning are both stubbed. The test checks the
    returned flag, joint ids and trajectory shape, that arm planning ran
    twice, and that the held-object state references the same semantics.
    """
    cfg = self._make_cfg()
    action = PickUpAction(self.mg, cfg=cfg)
    semantics = self._make_semantics()
    grasp_xpos = torch.eye(4).unsqueeze(0).repeat(NUM_ENVS, 1, 1)
    # NOTE(review): two hard-coded flags assume NUM_ENVS == 2 -- confirm.
    action._resolve_grasp_pose = Mock(
        return_value=(torch.tensor([True, True], dtype=torch.bool), grasp_xpos)
    )

    # Stand-in planner: always succeeds and returns a zero trajectory with
    # the requested number of waypoints.
    def plan_success(target_states_list, start_qpos, n_waypoints, arm_dof):
        return True, torch.zeros(NUM_ENVS, n_waypoints, arm_dof)

    action._plan_arm_trajectory = Mock(side_effect=plan_success)

    is_success, trajectory, joint_ids = action.execute(
        target=semantics,
        start_qpos=torch.zeros(NUM_ENVS, ARM_DOF),
    )

    assert is_success is True
    assert joint_ids == list(range(TOTAL_DOF))
    assert trajectory.shape == (NUM_ENVS, cfg.sample_interval, TOTAL_DOF)
    # Two planning calls are expected.
    # Presumably one is for the approach and one for the lift -- TODO confirm.
    assert action._plan_arm_trajectory.call_count == 2
    held_state = action.get_held_object_state()
    assert held_state is not None
    assert held_state.semantics is semantics
class TestMoveObjectAction:
    """Tests for MoveObjectAction without requiring simulation."""

    def setup_method(self):
        """Create a fresh mock robot and motion generator for each test."""
        self.robot = _make_mock_robot()
        self.mg = _make_mock_motion_generator(self.robot)

    def _make_cfg(self, **overrides):
        """Return a ``MoveObjectActionCfg`` with test defaults.

        Keyword arguments replace the matching default fields.
        """
        defaults = dict(
            hand_close_qpos=torch.tensor([0.025, 0.025]),
            control_part="arm",
            hand_control_part="hand",
            sample_interval=5,
        )
        defaults.update(overrides)
        return MoveObjectActionCfg(**defaults)

    def _make_held_state(self) -> HeldObjectState:
        """Return a held-object state with a 0.2 z offset in ``object_to_eef``.

        ``grasp_xpos`` is an identity pose for every environment.
        """
        semantics = ObjectSemantics(affordance=Affordance(), geometry={}, label="box")
        object_to_eef = torch.eye(4).unsqueeze(0).repeat(NUM_ENVS, 1, 1)
        object_to_eef[:, 2, 3] = 0.2
        grasp_xpos = torch.eye(4).unsqueeze(0).repeat(NUM_ENVS, 1, 1)
        return HeldObjectState(
            semantics=semantics,
            object_to_eef=object_to_eef,
            grasp_xpos=grasp_xpos,
        )

    def test_init_sets_hand_joint_ids(self):
        """Construction exposes hand joint ids, combined joint ids and dof."""
        action = MoveObjectAction(self.mg, cfg=self._make_cfg())
        assert action.hand_joint_ids == list(range(ARM_DOF, ARM_DOF + HAND_DOF))
        assert action.joint_ids == list(range(ARM_DOF)) + list(
            range(ARM_DOF, ARM_DOF + HAND_DOF)
        )
        assert action.dof == TOTAL_DOF

    def test_resolve_move_object_target_uses_held_transform(self):
        """The resolved pose combines the target with the held transform.

        The target contributes x = 0.3 and the held state's ``object_to_eef``
        contributes z = 0.2; both must appear in the per-environment result.
        """
        action = MoveObjectAction(self.mg, cfg=self._make_cfg())
        target_pose = torch.eye(4)
        target_pose[0, 3] = 0.3
        target = MoveObjectTarget(object_target_pose=target_pose)
        is_success, move_xpos, held_state = action._resolve_move_object_target(
            target,
            held_object_state=self._make_held_state(),
        )
        assert is_success is True
        assert held_state is not None
        # The unbatched (4, 4) target is expected to come back batched per env.
        assert move_xpos.shape == (NUM_ENVS, 4, 4)
        assert torch.allclose(move_xpos[:, 0, 3], torch.full((NUM_ENVS,), 0.3))
        assert torch.allclose(move_xpos[:, 2, 3], torch.full((NUM_ENVS,), 0.2))

    def test_resolve_move_object_target_requires_held_state(self):
        """Resolving a target without a held state raises ``ValueError``."""
        action = MoveObjectAction(self.mg, cfg=self._make_cfg())
        with pytest.raises(ValueError, match="requires a HeldObjectState"):
            action._resolve_move_object_target(MoveObjectTarget(torch.eye(4)))

    def test_execute_pads_closed_hand_and_preserves_held_state(self):
        """``execute`` appends the closed-hand qpos and keeps the held state.

        Arm planning is stubbed with an all-zero trajectory, so the hand
        columns of the result must equal ``cfg.hand_close_qpos`` everywhere.
        """
        cfg = self._make_cfg(sample_interval=4)
        action = MoveObjectAction(self.mg, cfg=cfg)
        plan_traj = torch.zeros(NUM_ENVS, cfg.sample_interval, ARM_DOF)
        action._plan_arm_trajectory = Mock(return_value=(True, plan_traj))
        held_state = self._make_held_state()

        is_success, trajectory, joint_ids = action.execute(
            target=MoveObjectTarget(torch.eye(4)),
            start_qpos=torch.zeros(NUM_ENVS, ARM_DOF),
            held_object_state=held_state,
        )

        assert is_success is True
        assert joint_ids == list(range(TOTAL_DOF))
        assert trajectory.shape == (NUM_ENVS, cfg.sample_interval, TOTAL_DOF)
        expected_hand = cfg.hand_close_qpos.expand(
            NUM_ENVS, cfg.sample_interval, HAND_DOF
        )
        assert torch.allclose(trajectory[:, :, ARM_DOF:], expected_hand)
        # The same state object passed in must still be the one reported.
        assert action.get_held_object_state() is held_state
test_move_object_target_passthrough(self): + target = MoveObjectTarget(object_target_pose=torch.eye(4)) + result = self.engine._resolve_target(target) + assert result is target + def test_string_resolved_via_semantic_analyzer(self): result = self.engine._resolve_target("mug") assert isinstance(result, ObjectSemantics) @@ -158,6 +164,16 @@ def test_dict_with_pose_raises_on_non_tensor(self): with pytest.raises(TypeError, match="must be a torch.Tensor"): self.engine._resolve_target({"pose": "not_a_tensor"}) + def test_dict_with_object_target_pose(self): + pose = torch.eye(4) + result = self.engine._resolve_target({"object_target_pose": pose}) + assert isinstance(result, MoveObjectTarget) + assert result.object_target_pose is pose + + def test_dict_with_object_target_pose_raises_on_non_tensor(self): + with pytest.raises(TypeError, match="must be a torch.Tensor"): + self.engine._resolve_target({"object_target_pose": "not_a_tensor"}) + def test_dict_with_semantics_key(self): sem = ObjectSemantics(affordance=Affordance(), geometry={}, label="bottle") result = self.engine._resolve_target({"semantics": sem}) @@ -185,6 +201,102 @@ def test_unsupported_type_raises(self): self.engine._resolve_target(42) +# --------------------------------------------------------------------------- +# AtomicActionEngine held-object state contract +# --------------------------------------------------------------------------- + + +class TestHeldObjectStateContract: + """Tests for explicit held-object state updates in execute_static.""" + + def setup_method(self): + self.robot = Mock() + self.robot.device = torch.device("cpu") + self.robot.dof = 6 + self.robot.get_qpos.return_value = torch.zeros(1, 6) + self.robot.get_joint_ids.return_value = list(range(6)) + + self.mg = Mock() + self.mg.robot = self.robot + self.mg.device = torch.device("cpu") + + self.engine = AtomicActionEngine(self.mg, actions_cfg_list=[]) + + def _make_action(self, *, updates_held_object_state: bool, held_state): + action = 
Mock() + action.control_part = "arm" + action.updates_held_object_state = updates_held_object_state + action.execute.return_value = ( + True, + torch.zeros(1, 2, 6), + list(range(6)), + ) + action.get_held_object_state.return_value = held_state + return action + + def test_execute_static_clears_stale_context_before_running_actions(self): + held_state = object() + action = self._make_action( + updates_held_object_state=False, + held_state=None, + ) + self.engine._actions = {"noop": action} + self.engine._action_context["held_object_state"] = held_state + + is_success, _ = self.engine.execute_static([torch.eye(4)]) + + assert is_success is True + action.get_held_object_state.assert_not_called() + assert "held_object_state" not in self.engine._action_context + + def test_execute_static_updates_state_when_action_declares_contract(self): + held_state = object() + action = self._make_action( + updates_held_object_state=True, + held_state=held_state, + ) + self.engine._actions = {"producer": action} + + is_success, _ = self.engine.execute_static([torch.eye(4)]) + + assert is_success is True + action.get_held_object_state.assert_called_once_with() + assert self.engine._action_context["held_object_state"] is held_state + + def test_execute_static_clears_state_when_action_returns_none(self): + held_state = object() + action = self._make_action( + updates_held_object_state=True, + held_state=None, + ) + self.engine._actions = {"release": action} + self.engine._action_context["held_object_state"] = held_state + + is_success, _ = self.engine.execute_static([torch.eye(4)]) + + assert is_success is True + action.get_held_object_state.assert_called_once_with() + assert "held_object_state" not in self.engine._action_context + + def test_execute_static_keeps_state_within_single_action_sequence(self): + held_state = object() + producer = self._make_action( + updates_held_object_state=True, + held_state=held_state, + ) + release = self._make_action( + updates_held_object_state=True, + 
held_state=None, + ) + self.engine._actions = {"producer": producer, "release": release} + + is_success, _ = self.engine.execute_static([torch.eye(4), torch.eye(4)]) + + assert is_success is True + assert release.execute.call_args.kwargs["held_object_state"] is held_state + assert "held_object_state" not in self.engine._action_context + + if __name__ == "__main__": test = TestSemanticAnalyzer() test.setup_method()