openpi icon indicating copy to clipboard operation
openpi copied to clipboard

Does pi0 model work out of the box in aloha sim?

Open longyuzhao opened this issue 2 months ago • 0 comments

I have set up pi0 model in a remote server and aloha sim in my local device. Basically, I have a script to get the observation from the simulator each timestep, (after certain timesteps) send it to the pi0 model, get the action and send the action to the simulator. I have tried HandOverBanana task with pi0_aloha_sim checkpoint in EnvMode.ALOHA_SIM, and TowelFoldInHalf task with pi0_aloha_towel checkpoint in EnvMode.ALOHA. However, the arms move around but do not seem to do anything meaningful. I also tried replanning the actions every 1/5/10 steps (I am only using the first action in the action trajectory), but does not seem to make a difference. Wonder if this is expected? Or maybe I should sample the actions in a different frequency, etc.? Thanks.

import sys
import os

sys.path.insert(0, os.path.expanduser('~/code/aloha_sim'))

from dm_control import viewer
import numpy as np
from aloha_sim import task_suite
from openpi_client import image_tools
from openpi_client import websocket_client_policy


class Pi0Policy:
    """Policy class that queries Pi0 server for actions."""
    
    def __init__(self, client, task_prompt, replan_frequency=10, execute_steps_per_action=5):
        """
        Args:
            client: WebsocketClientPolicy instance
            task_prompt: Task instruction string
            replan_frequency: Query Pi0 every N steps
            execute_steps_per_action: Hold each action for N steps
        """
        self.client = client
        self.task_prompt = task_prompt
        self.replan_frequency = replan_frequency
        self.execute_steps_per_action = execute_steps_per_action
        
        # State
        self.action_buffer = None
        self.action_index = 0
        self.current_action = None
        self.step_counter = 0
        self.action_hold_counter = 0
        
    def process_observation(self, obs):
        """Convert ALOHA sim observation to Pi0 format."""
        try:
            # Extract images
            img_overhead = obs['overhead_cam']
            img_worms_eye = obs['worms_eye_cam']
            img_wrist_left = obs['wrist_cam_left']
            img_wrist_right = obs['wrist_cam_right']
            
            # Extract joint states
            joint_positions = obs['joints_pos']
            
            # Create Pi0 observation
            observation = {
                "state": joint_positions.astype(np.float64),
                "images": {
                    "cam_high": np.transpose(
                        image_tools.convert_to_uint8(
                            image_tools.resize_with_pad(img_overhead, 224, 224)
                        ), (2, 0, 1)
                    ),
                    "cam_low": np.transpose(
                        image_tools.convert_to_uint8(
                            image_tools.resize_with_pad(img_worms_eye, 224, 224)
                        ), (2, 0, 1)
                    ),
                    "cam_left_wrist": np.transpose(
                        image_tools.convert_to_uint8(
                            image_tools.resize_with_pad(img_wrist_left, 224, 224)
                        ), (2, 0, 1)
                    ),
                    "cam_right_wrist": np.transpose(
                        image_tools.convert_to_uint8(
                            image_tools.resize_with_pad(img_wrist_right, 224, 224)
                        ), (2, 0, 1)
                    ),
                },
                "prompt": self.task_prompt,
            }
            return observation
        except KeyError as e:
            print(f"Missing observation key: {e}")
            print(f"Available keys: {obs.keys()}")
            return None
    
    def __call__(self, timestep):
        """Policy function called by the viewer every step."""
        obs = timestep.observation
        
        # Query Pi0 every replan_frequency steps
        if self.step_counter % self.replan_frequency == 0:
            print(f"\n=== Step {self.step_counter}: Replanning ===")
            
            # Process observation
            pi0_obs = self.process_observation(obs)
            if pi0_obs is None:
                return np.zeros(14)  # ALOHA has 14D action space
            
            try:
                # Get action chunk from Pi0
                action_chunk = self.client.infer(pi0_obs)["actions"]
                self.action_buffer = action_chunk
                self.action_index = 0
                self.action_hold_counter = 0
                print(f"Received {self.action_buffer.shape[0]} actions from Pi0")
            except Exception as e:
                print(f"Error querying Pi0: {e}")
                return np.zeros(14)
        
        # Get next action from buffer every execute_steps_per_action steps
        if self.action_hold_counter % self.execute_steps_per_action == 0:
            if self.action_buffer is not None and self.action_index < len(self.action_buffer):
                self.current_action = self.action_buffer[self.action_index]
                self.action_index += 1
                print(f"  Step {self.step_counter}: Using action {self.action_index}/{len(self.action_buffer)}")
        
        # Hold the current action for multiple steps
        if self.current_action is not None:
            action = self.current_action
        else:
            action = np.zeros(14)
        
        # Increment counters
        self.step_counter += 1
        self.action_hold_counter += 1
        
        # Handle episode end
        if timestep.last():
            print(f"\n{'='*60}")
            print(f"Episode ended at step {self.step_counter}")
            print(f"Final reward: {timestep.reward}")
            print(f"{'='*60}\n")
            self.reset()
        
        return action
    
    def reset(self):
        """Reset policy state for new episode."""
        self.action_buffer = None
        self.action_index = 0
        self.current_action = None
        self.step_counter = 0
        self.action_hold_counter = 0


def main():
    # Create ALOHA sim environment
    env = task_suite.create_task_env('TowelFoldInHalf', time_limit=80.0)
    
    # Connect to Pi0 server
    client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8001)
    
    # Create policy
    policy = Pi0Policy(
        client=client,
        task_prompt="fold the towel in half",
        replan_frequency=10,
        execute_steps_per_action=10
    )
    
    # Launch viewer
    print("=" * 60)
    print("Launching ALOHA Sim with Pi0 Policy")
    print("=" * 60)
    print(f"Task: {policy.task_prompt}")
    print(f"Replanning frequency: every {policy.replan_frequency} steps")
    print(f"Action hold duration: {policy.execute_steps_per_action} steps")
    print("Make sure Pi0 server is running on localhost:8001")
    print("=" * 60)
    
    viewer.launch(env, policy=policy)


if __name__ == "__main__":
    main()

longyuzhao avatar Oct 17 '25 23:10 longyuzhao