Integration with torchRL

Open ShahRutav opened this issue 2 years ago • 3 comments

I modified the getting-started example to run torchrl with robohive. Here is the modified example:

import torch
import robohive
from torchrl.envs import RoboHiveEnv
from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform

from rlhive.rl_envs import make_r3m_env
from torchrl.collectors.collectors import SyncDataCollector, MultiaSyncDataCollector, RandomPolicy
# make sure your ParallelEnv is inside the `if __name__ == "__main__":` condition, otherwise you'll
# be creating an infinite tree of subprocesses
if __name__ == "__main__":
    device = torch.device("cpu") # could be 'cuda:0'
    env_name = 'FrankaReachFixed-v0'
    env = make_r3m_env(env_name, model_name="resnet18", download=True)
    assert env.device == device
    # example of a rollout
    print(env.rollout(3))

Additionally, I changed this line so that the visual keys are filtered out when the R3M transform output is concatenated with the other keys:

vec_keys = [k for k in base_env.observation_spec.keys() if ((k != "pixels") and ("visual" not in k))]
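To make the effect of that filter concrete, here is a minimal sketch on a plain list of strings. The key names in `obs_keys` are hypothetical stand-ins; the real ones come from `base_env.observation_spec.keys()` and depend on the environment.

```python
# Hypothetical observation keys, for illustration only; the real keys
# come from base_env.observation_spec.keys().
obs_keys = ["pixels", "visual_rgb", "qp_robot", "qv_robot", "reach_err"]

# Same condition as the modified line: drop "pixels" and any key
# whose name contains "visual".
vec_keys = [k for k in obs_keys if (k != "pixels") and ("visual" not in k)]

print(vec_keys)  # ['qp_robot', 'qv_robot', 'reach_err']
```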

With this change, the example fails with the following error:

Traceback (most recent call last):
  File "test.py", line 16, in <module>
    print(env.rollout(3))
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/common.py", line 1797, in rollout
    tensordict = self.reset()
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/common.py", line 1480, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 760, in _reset
    tensordict_reset = self.transform._reset(tensordict, tensordict_reset)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 1020, in _reset
    tensordict_reset = t._reset(tensordict, tensordict_reset)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 3694, in _reset
    tensordict_reset = self._call(tensordict_reset)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 3676, in _call
    out_tensor = torch.cat(values, dim=self.dim)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/tensordict/tensordict.py", line 2785, in __torch_function__
    return TD_HANDLED_FUNCTIONS[func](*args, **kwargs)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/tensordict/tensordict.py", line 5346, in _cat
    batch_size = list(list_of_tensordicts[0].batch_size)
AttributeError: 'Tensor' object has no attribute 'batch_size'

I am using the following package versions: robohive==0.6.0, tensordict==0.2.1, torchrl==0.2.1. Which versions did you use? @vmoens

ShahRutav avatar Dec 22 '23 06:12 ShahRutav

On it! Will ping you soon with a solution

vmoens avatar Dec 22 '23 21:12 vmoens

I edited https://github.com/facebookresearch/agenthive/pull/22. You can have a look; this example should work fine now:

import torch
import robohive
print(robohive.robohive_env_suite)
# from torchrl.envs import RoboHiveEnv
# from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform

from rlhive.rl_envs import make_r3m_env
# from torchrl.collectors.collectors import SyncDataCollector, MultiaSyncDataCollector, RandomPolicy
# make sure your ParallelEnv is inside the `if __name__ == "__main__":` condition, otherwise you'll
# be creating an infinite tree of subprocesses
if __name__ == "__main__":
    device = torch.device("cpu") # could be 'cuda:0'
    env_name = 'FrankaReachFixed-v0'
    env = make_r3m_env(env_name, model_name="resnet18", download=True)
    assert env.device == device
    # example of a rollout
    print(env.rollout(3))

vmoens avatar Dec 23 '23 06:12 vmoens

Thanks, env.rollout(3) works after these changes. Taking the test example a step further, I collected data with a single-process SyncDataCollector and a multi-process MultiaSyncDataCollector. Below is the code snippet:

import torch
import robohive
# printing the envs in robohive env_suite
# print(robohive.robohive_env_suite)

from rlhive.rl_envs import make_r3m_env
from torchrl.collectors.collectors import SyncDataCollector, MultiaSyncDataCollector, RandomPolicy
# make sure your ParallelEnv is inside the `if __name__ == "__main__":` condition, otherwise you'll
# be creating an infinite tree of subprocesses
if __name__ == "__main__":
    device = torch.device("cpu") # could be 'cuda:0'
    env_name = 'FrankaReachFixed-v0'
    env = make_r3m_env(env_name, model_name="resnet18", download=True)
    assert env.device == device
    # example of a rollout
    print(env.rollout(3))

    # a simple, single-process data collector
    collector = SyncDataCollector(
        env,
        policy=RandomPolicy(env.action_spec),
        total_frames=1_000,
        frames_per_batch=200,
        init_random_frames=200,
    )
    for data in collector:
        print(data)

    # async multi-proc data collector
    collector = MultiaSyncDataCollector(
        [env, env],
        policy=RandomPolicy(env.action_spec),
        total_frames=1_000,
        frames_per_batch=200,
        init_random_frames=200,
    )
    for data in collector:
        print(data)

  • env.rollout(3) works without any error.
  • The single-process SyncDataCollector works as well.
  • However, MultiaSyncDataCollector with two processes fails with the following error:

  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/collectors/collectors.py", line 839, in rollout
    tensordict, tensordict_ = self.env.step_and_maybe_reset(
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/common.py", line 1942, in step_and_maybe_reset
    tensordict = self.step(tensordict)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/common.py", line 1313, in step
    next_tensordict = self._step(tensordict)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 735, in _step
    next_tensordict = self.transform._step(tensordict, next_tensordict)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 970, in _step
    next_tensordict = t._step(tensordict, next_tensordict)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 318, in _step
    next_tensordict = self._call(next_tensordict)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 3681, in _call
    raise Exception(
Exception: CatTensor failed, as it expected input keys = ['pixel_r3m', 'qp_robot', 'qv_robot', 'reach_err', 'solved'] but got a TensorDict with keys ['done', 'pixel_r3m', 'qp_robot', 'qv_robot', 'reach_err', 'reward', 'terminated', 'truncated']

The behavior is not the same in SyncDataCollector and MultiaSyncDataCollector.
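
To make the mismatch easier to read, here is a small sketch that diffs the two key lists copied from the exception message above. The helper `diff_keys` is a hypothetical name used only for illustration; it is not part of torchrl or tensordict.

```python
# Key lists copied verbatim from the CatTensor exception message.
expected = ['pixel_r3m', 'qp_robot', 'qv_robot', 'reach_err', 'solved']
received = ['done', 'pixel_r3m', 'qp_robot', 'qv_robot', 'reach_err',
            'reward', 'terminated', 'truncated']


def diff_keys(expected, received):
    """Return (missing, extra): keys expected but absent, and keys present but not expected."""
    missing = sorted(set(expected) - set(received))
    extra = sorted(set(received) - set(expected))
    return missing, extra


missing, extra = diff_keys(expected, received)
print("missing:", missing)  # missing: ['solved']
print("extra:", extra)      # extra: ['done', 'reward', 'terminated', 'truncated']
```

The only expected key that is absent in the multi-process case is 'solved'; the extra keys are the usual step outputs.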

ShahRutav avatar Dec 23 '23 14:12 ShahRutav