[BUG] Training an agent on RoboHive env using torchRL
Describe the bug
I wasn't sure whether to open this issue on torchRL, agenthive, or the robohive repo. Apologies if it's in the wrong place.
I'm trying to train a PPO agent on the Franka Kitchen environments, and the agent is not able to converge on most of the tasks. I have tried feeding in the full state of the environment as input (`body_pos`, `body_quat`, `site_pos`, `site_quat`, `qpos`, `qvel`, `ee_pose`). That failed to converge and led to the model predicting some invalid actions, causing the robot arm to point straight up in the air, along with the following warning:
Simulation couldn't be stepped as intended. Issuing a reset
WARNING:absl:Nan, Inf or huge value in CTRL at ACTUATOR 0. The simulation is unstable. Time = 0.0000.
When I used the `site_err`s as input, it converged on a few tasks, but still failed on most.
Is there a canonical way to train an agent using torchRL on a RoboHive env?
To Reproduce
Attached training script for reference (slightly modified from the PPO tutorial).
Expected behavior
Training converges
System info
Describe the characteristics of your environment:
- How the library was installed: pip
- Python version: 3.8
- Versions of any other relevant libraries: robohive v0.6
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.3.0 1.24.4 3.8.18 (default, Nov 29 2023, 16:26:57)
[GCC 13.2.1 20230801] linux
Checklist
- [X] I have checked that there is no similar issue in the repo (required)
- [X] I have read the documentation (required)
- [X] I have provided a minimal working example to reproduce the bug (required)
This is the right place!
Can you share a little more about your setup: how you build the env (are you using transforms?), the PPO loss, and how you're using PPO and GAE + storing data?
Thanks!
cc @vikashplus
You can also take a pass at some of our pre-trained agents to compare results.
> Can you share a little more about your setup: how you build the env (are you using transforms?), the PPO loss, and how you're using PPO and GAE + storing data?
Creating the env:
from torchrl.envs import (
    CatTensors,
    Compose,
    DoubleToFloat,
    FlattenObservation,
    ObservationNorm,
    StepCounter,
    TransformedEnv,
)
from torchrl.envs.libs.robohive import RoboHiveEnv


def create_env(env_name, device, from_pixels=False, pixels_only=True):
    base_env = RoboHiveEnv(
        env_name,
        device=device,
        from_pixels=from_pixels,
        pixels_only=pixels_only,
    )
    env = TransformedEnv(
        base_env,
        Compose(
            # flatten the per-body / per-site (N, 3) and (N, 4) arrays into vectors
            FlattenObservation(
                first_dim=-2,
                last_dim=-1,
                in_keys=[
                    ("state", "body_pos"),
                    ("state", "body_quat"),
                    ("state", "site_pos"),
                    ("state", "site_quat"),
                ],
            ),
            # concatenate all state entries into a single "observation" vector
            CatTensors(
                in_keys=[
                    ("state", "body_pos"),
                    ("state", "body_quat"),
                    ("state", "site_pos"),
                    ("state", "site_quat"),
                    ("state", "qpos"),
                    ("state", "qvel"),
                    "ee_pose",
                ],
                out_key="observation",
            ),
            # normalize observations (loc/scale are initialized via init_stats before use)
            ObservationNorm(in_keys=["observation"]),
            DoubleToFloat(),
            StepCounter(),
        ),
    )
    return env
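For completeness, here is roughly how the env gets used before any data is collected; this is a minimal sketch of the standard tutorial flow, where the env id and the `init_stats` arguments are placeholders rather than the exact values in the attached script:

from torchrl.envs.utils import check_env_specs

env = create_env("FK1_MicroOpenFixed-v4", device="cpu")  # placeholder env id

# the ObservationNorm transform is the third entry of the Compose above;
# its loc/scale are estimated from random rollouts before training starts
env.transform[2].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)

# sanity-check that reset/step outputs match the declared specs
check_env_specs(env)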
PPO loss, GAE, and data storage are all identical to the tutorial; I have not made any changes. The full training script is attached in the previous comment, which should help in reproducing the result.
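For reference, the relevant part of the tutorial looks roughly like this (a sketch; `policy_module` and `value_module` stand in for the tutorial's actor and critic, and the hyperparameters shown are placeholders):

from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

# GAE turns the critic's value estimates into advantages
advantage_module = GAE(
    gamma=0.99, lmbda=0.95, value_network=value_module, average_gae=True
)

# clipped PPO objective: actor first, critic second
loss_module = ClipPPOLoss(
    policy_module,
    value_module,
    clip_epsilon=0.2,
    entropy_coef=1e-4,
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
)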
> You can also take a pass at some of our pre-trained agents to compare results.
I did check out the agenthive release, but I had trouble getting it working. It appears to be `mjrl`, not `torchrl`, and since it depends on the old mujoco bindings, setup was painful. I'll take another crack at it, but I would very much appreciate getting this set up with torchrl!
This seems to be a case of the right input not being passed. I tried training again with just `ee_pose` and `obj_goal` as the inputs to the policy. This performed marginally better, with more tasks converging. Many others failed to converge, and only one "broke" the simulation.
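Concretely, the only change from the `create_env` above was the `CatTensors` key list (a sketch; I'm assuming `obj_goal` sits at the top level of the tensordict like `ee_pose` does — adjust the key if it lives under the `state` sub-tensordict):

from torchrl.envs import CatTensors

# reduced observation: end-effector pose + object goal only
# (the FlattenObservation over body/site keys is no longer needed in this case)
CatTensors(
    in_keys=["ee_pose", "obj_goal"],
    out_key="observation",
)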
I also checked a bit more on why the simulation breaks and the robot arm points into the air. During training, the `loss_objective` term of the PPO loss sometimes goes to `NaN`, and after the loss update the entire network goes to `NaN`, which then breaks the simulation. Not sure why this is happening; I will try some more experiments and report the results.
@sriramsk1999 - any more updates on this? Here are a few things to check:
- Check that the observations from the following match those of the env you register. Be careful to seed both envs identically with `ee.unwrapped.seed(<seed>)` (a comparison sketch follows this list).
In [1]: import robohive
In [2]: import gymnasium as gym
In [3]: gym.make('FK1_MicroOpenFixed-v4')
In [4]: ee = gym.make('FK1_MicroOpenFixed-v4')
In [5]: obs = ee.unwrapped.get_obs()
In [6]: reset_obs = ee.unwrapped.reset()
- Check that the exploration of the initial policy is sensible (small movements around the mean position).
- Check that the policy after the 1st PPO iteration is doing sensible things (small movements around the mean position).
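A minimal sketch of the first check, assuming the torchrl env is built with the `create_env` from earlier in the thread and that the same env id is valid for both wrappers (the two vectors are only expected to agree up to the ordering of the concatenated keys):

import numpy as np
import gymnasium as gym
import robohive  # registers the FK1_* envs

SEED = 0

# raw robohive env
ee = gym.make("FK1_MicroOpenFixed-v4")
ee.unwrapped.seed(SEED)
ee.unwrapped.reset()
gym_obs = ee.unwrapped.get_obs()

# torchrl env, seeded identically
env = create_env("FK1_MicroOpenFixed-v4", device="cpu")
env.set_seed(SEED)
torchrl_obs = env.reset()["observation"].cpu().numpy()

print(gym_obs.shape, torchrl_obs.shape)
# only compare element-wise when the key ordering/shapes actually line up
if gym_obs.shape == torchrl_obs.shape:
    print(np.allclose(gym_obs, torchrl_obs))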
cc @vmoens
- The `reset_obs` values match. Before seeding, the gym and torchrl observations do not match; after seeding both envs, they match as well.
- For the policy itself: last time, when I trained with `MicroOpenRandom-v4`, the loss went to NaN and broke the simulation. This time the loss did not go to NaN, but the network didn't converge :sweat_smile:
- I also tried `MicroOpenFixed-v4`. It does better than `MicroOpenRandom-v4`, but still doesn't fully converge, even though it is a fixed environment.

I think I am not passing the right inputs to the model. @vmoens, what keys of the tensordict should I use to get the vector returned by `ee.unwrapped.get_obs()`?
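For context, this is how I am inspecting what the wrapped env actually exposes (just a sketch using the generic torchrl env API; the printed keys are whatever the RoboHive wrapper and the transforms put in the tensordict):

env = create_env("FK1_MicroOpenFixed-v4", device="cpu")

# reset returns a tensordict; its entries are the candidate observation keys
td = env.reset()
print(td)                    # keys, shapes and dtypes
print(env.observation_spec)  # declared observation spec after the transforms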
While running more experiments, I noticed that training only went to NaN sometimes; on retrying a few times, the policy eventually converged. In this way I was able to get good results for 16 of the 20 Franka Kitchen tasks; only these 4 environments did not succeed: `Stove1Kettle`, `Stove4Kettle`, `LdoorOpen`, `LdoorClose`. I've added the accuracies for the converged tasks below for completeness' sake. Closing this issue now, since there doesn't seem to be any issue with torchRL or the training code, just that the tasks are hard.
Env | Acc (PPO) (%) |
---|---|
Knob1Off | 66 |
Knob1On | 90 |
Knob2Off | 84 |
Knob2On | 82 |
Knob3Off | 80 |
Knob3On | 84 |
Knob4Off | 70 |
Knob4On | 94 |
LightOn | 98 |
LightOff | 98 |
MicroClose | 96 |
MicroOpen | 82 |
SdoorClose | 92 |
SdoorOpen | 98 |
RdoorClose | 98 |
RdoorOpen | 82 |
Average | 87.125 |
Thanks for the updates @sriramsk1999. Much appreciated.