rl icon indicating copy to clipboard operation
rl copied to clipboard

[BUG] Problems with BatchedEnv on accelerated device with single envs on cpu

Open skandermoalla opened this issue 1 year ago • 29 comments

Describe the bug

When the batched env device is cuda, the step count on the batched env seems completely off from what it should be. When the batched env device is mps, there is a segmentation fault.

I wonder if this is only the step count that is corrupted or any other data including the observation ...

To Reproduce

from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    StepCounter,
    TransformedEnv, SerialEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

max_step = 10
n_env = 4
env_id = "CartPole-v1"
device = "mps"


def build_cpu_single_env():
    """Build one CPU-resident Gym env wrapped with a step counter.

    The env is created from the module-level ``env_id`` on ``device="cpu"``,
    wrapped in a ``TransformedEnv`` and given a ``StepCounter`` limited to
    ``max_step`` steps. The counter writes to custom keys
    (``single_env_step_count`` / ``single_env_truncated``).
    """
    env = GymEnv(env_id, device="cpu")
    env = TransformedEnv(env)
    env.append_transform(StepCounter(max_steps=max_step, step_count_key="single_env_step_count", truncated_key="single_env_truncated"))
    return env

def build_actor(env):
    """Return a ``ProbabilisticActor`` for ``env``.

    A ``LazyLinear`` layer sized to ``env.action_spec.space.n`` maps
    ``"observation"`` to ``"logits"``; the actor then uses those logits with
    ``OneHotCategorical`` and ``ExplorationType.RANDOM`` as default
    interaction type.
    """
    return ProbabilisticActor(
        module=TensorDictModule(
            nn.LazyLinear(env.action_spec.space.n),
            in_keys=["observation"],
            out_keys=["logits"],
        ),
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=["logits"],
        default_interaction_type=ExplorationType.RANDOM,
    )

if __name__ == "__main__":
    # Batched env is placed on `device` while every sub-env is built on CPU.
    env = SerialEnv(n_env, EnvCreator(lambda: build_cpu_single_env()), device=device)
    policy_module = build_actor(env)
    policy_module.to(device)
    # One call on a reset tensordict before rolling out
    # (presumably to materialize the LazyLinear weights -- TODO confirm).
    policy_module(env.reset())

    for i in range(10):
        # Roll out 3 steps past the StepCounter limit without stopping on done.
        batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
        max_step_count = batches["next", "single_env_step_count"].max().item()
        # Any recorded count above the limit is reported as the problem.
        if max_step_count > max_step:
            print("Problem!")
            print(max_step_count)
            break
    else:
        # for/else: reached only when no rollout hit the `break` above.
        print("No problem!")

On CUDA

Problem!
1065353217

On MPS

python(57380,0x1dd5e5000) malloc: Incorrect checksum for freed object 0x11767f308: probably modified after being freed.
Corrupt value: 0xbd414ea83cfeb221
python(57380,0x1dd5e5000) malloc: *** set a breakpoint in malloc_error_break to debug
[1]    57380 abort      python tests/issue_env_device.py

System info

import torchrl, tensordict, torch, numpy, sys
print(torch.__version__, tensordict.__version__, torchrl.__version__, numpy.__version__, sys.version, sys.platform)

2.2.0a0+81ea7a4 0.4.0+eaef29e 0.4.0+01a2216 1.24.4 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] linux

2.2.0 0.4.0+eaef29e 0.4.0+01a2216 1.26.3 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:35:25) [Clang 16.0.6 ] darwin

skandermoalla avatar Feb 01 '24 17:02 skandermoalla

To reproduce the bug on ParallelEnv you need some wizardry:

  • Change the environment to "MountainCar-v0".
  • Change the truncation key, otherwise you fall into #1865
  • Add an empty transformed env to the batched env.
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    StepCounter,
    TransformedEnv,
    ParallelEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

max_step = 10
n_env = 4
env_id = "MountainCar-v0"
device = "cuda:0"


def build_cpu_single_env():
    """Build one CPU-resident Gym env wrapped with a step counter.

    The ``StepCounter`` is limited to ``max_step`` steps and writes its
    truncation signal under the custom key ``"foo"``; the step count key is
    left at its default.
    """
    env = GymEnv(env_id, device="cpu")
    env = TransformedEnv(env)
    env.append_transform(StepCounter(max_steps=max_step, truncated_key="foo"))
    return env


def build_actor(env):
    """Return a ``ProbabilisticActor`` for ``env``.

    A ``LazyLinear`` layer sized to ``env.action_spec.space.n`` maps
    ``"observation"`` to ``"logits"``; the actor then uses those logits with
    ``OneHotCategorical`` and ``ExplorationType.RANDOM`` as default
    interaction type.
    """
    return ProbabilisticActor(
        module=TensorDictModule(
            nn.LazyLinear(env.action_spec.space.n),
            in_keys=["observation"],
            out_keys=["logits"],
        ),
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=["logits"],
        default_interaction_type=ExplorationType.RANDOM,
    )


if __name__ == "__main__":
    # Batched env is placed on `device` while every sub-env is built on CPU.
    env = ParallelEnv(n_env, EnvCreator(lambda: build_cpu_single_env()), device=device)
    # Empty TransformedEnv around the batched env (no transform is passed).
    env = TransformedEnv(env)
    policy_module = build_actor(env)
    policy_module.to(device)
    # One call on a reset tensordict before rolling out
    # (presumably to materialize the LazyLinear weights -- TODO confirm).
    policy_module(env.reset())
    for i in range(10):
        # Roll out 3 steps past the StepCounter limit without stopping on done.
        batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
        max_step_count = batches["next", "step_count"].max().item()
        # Any recorded count above the limit is reported as the problem.
        if max_step_count > max_step:
            print("Problem!")
            print(max_step_count)
            break
    else:
        # for/else: reached only when no rollout hit the `break` above.
        print("No problem!")

skandermoalla avatar Feb 01 '24 19:02 skandermoalla

I can reproduce the initial example iff the env is on "cpu", so it's likely just a problem of casting from device to device in SerialEnv. I will check that tomorrow!

vmoens avatar Feb 01 '24 21:02 vmoens

Nice, thanks! Indeed, it's probably device casting gone wrong somewhere, as MPS literally crashed with a segfault. Could you reproduce the one with ParallelEnv? That's as impactful as the SerialEnv one.

skandermoalla avatar Feb 02 '24 10:02 skandermoalla

Can you have a go at 1866 for cpu envs? With me it works on sub-envs on cpu and cuda (even with 100 outer steps)

vmoens avatar Feb 02 '24 12:02 vmoens

Also, I ran the second example with 10K outer iterations but could not reproduce the issue (on the branch of the PR, but I did not change much from main), so I'm not sure how to address this

VERBOSE=1 python -c """import tqdm
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    StepCounter,
    TransformedEnv,
    ParallelEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

max_step = 10
n_env = 4
env_id = 'MountainCar-v0'
device = 'cuda:0'


# Build one CPU-resident Gym env wrapped in a TransformedEnv with a
# StepCounter limited to max_step steps; truncation is written under 'foo'.
def build_cpu_single_env():
    env = GymEnv(env_id, device='cpu')
    env = TransformedEnv(env)
    env.append_transform(StepCounter(max_steps=max_step, truncated_key='foo'))
    return env


# Return a ProbabilisticActor for env: a LazyLinear layer sized to
# env.action_spec.space.n maps 'observation' to 'logits', which feed a
# OneHotCategorical with ExplorationType.RANDOM as default interaction type.
def build_actor(env):
    return ProbabilisticActor(
        module=TensorDictModule(
            nn.LazyLinear(env.action_spec.space.n),
            in_keys=['observation'],
            out_keys=['logits'],
        ),
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=['logits'],
        default_interaction_type=ExplorationType.RANDOM,
    )


if __name__ == '__main__':
    # Batched env is placed on device while every sub-env is built on CPU.
    env = ParallelEnv(n_env, EnvCreator(lambda: build_cpu_single_env()), device=device)
    # Empty TransformedEnv around the batched env (no transform is passed).
    env = TransformedEnv(env)
    policy_module = build_actor(env)
    policy_module.to(device)
    # One call on a reset tensordict before rolling out
    # (presumably to materialize the LazyLinear weights -- TODO confirm).
    policy_module(env.reset())
    # 10000 outer iterations with a tqdm progress bar.
    for i in tqdm.tqdm(range(10000)):
        # Roll out 3 steps past the StepCounter limit without stopping on done.
        batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
        max_step_count = batches['next', 'step_count'].max().item()
        # Any recorded count above the limit is reported, with the full tensor.
        if max_step_count > max_step:
            print('Problem!')
            print(max_step_count)
            print(batches['next', 'step_count'])
            break
    else:
        # for/else: reached only when no rollout hit the break above.
        print('No problem!')
"""

vmoens avatar Feb 02 '24 12:02 vmoens

VERBOSE=1 python -c """import tqdm                                                                                                                                                                                                                 
from tensordict.nn import TensorDictModule
from torch import nn                                             
from torchrl.envs import (
    EnvCreator,                                                               
    ExplorationType,                                                                   
    StepCounter,         
    TransformedEnv,
    ParallelEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

max_step = 10                                                                 
n_env = 4                                                                              
env_id = 'MountainCar-v0'
device = 'cuda:0'

...
/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
  logger.warn(
/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
  logger.warn(
2024-02-02 13:10:48,627 [torchrl][INFO] resetting implement_for
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
  warnings.warn('Lazy modules are a new feature under heavy development '
2024-02-02 13:10:48,872 [torchrl][INFO] initiating worker 0
2024-02-02 13:10:48,965 [torchrl][INFO] initiating worker 1
2024-02-02 13:10:48,967 [torchrl][INFO] initiating worker 2
2024-02-02 13:10:48,969 [torchrl][INFO] initiating worker 3
2024-02-02 13:10:51,116 [torchrl][INFO] resetting implement_for
2024-02-02 13:10:51,142 [torchrl][INFO] resetting implement_for
2024-02-02 13:10:51,192 [torchrl][INFO] resetting implement_for
2024-02-02 13:10:51,196 [torchrl][INFO] resetting implement_for
  0%|                                                                                                                                                                                                                      | 0/10000 [00:00<?, ?it/s]Problem!
3201372667
tensor([[[         1],
         [         2],
         [         3],
         [         4],
         [         5],
         [         6],
         [         7],
         [         8],
         [         9],
         [        10],
         [         1],
         [         1],
         [         1]],

        [[         1],
         [         2],
         [         3],
         [         4],
         [         5],
         [         6],
         [         7],
         [         8],
         [         9],
         [        10],
         [         1],
         [         1],
         [3201372667]],

        [[         1],
         [         2],
         [         3],
         [         4],
         [         5],
         [         6],
         [         7],
         [         8],
         [         9],
         [        10],
         [         1],
         [         1],
         [         1]],

        [[         1],
         [         2],
         [         3],
         [         4],
         [         5],
         [         6],
         [         7],
         [         8],
         [         9],
         [        10],
         [         1],
         [         1],
         [         1]]], device='cuda:0')
  0%|            
❯ python                                   
Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torchrl, tensordict, torch, numpy, sys

>>> print(torch.__version__, tensordict.__version__, torchrl.__version__, numpy.__version__, sys.version, sys.platform)
2.2.0a0+81ea7a4 0.4.0+eaef29e 0.4.0+01a2216 1.24.4 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] linux

I will try with the 1866 branch and on another cluster.

skandermoalla avatar Feb 02 '24 13:02 skandermoalla

SerialEnv example was solved with #1866. I also tried poking a bit and it was fine.

I will check the ParallelEnv one.

skandermoalla avatar Feb 02 '24 13:02 skandermoalla

The problem is now different with ParallelEnv; that's probably why it didn't error for you @vmoens.

VERBOSE=1 python -c """import tqdm
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    StepCounter,
    TransformedEnv,
    ParallelEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

max_step = 10
n_env = 4
env_id = 'MountainCar-v0'
device = 'cuda:0'


# Build one CPU-resident Gym env wrapped in a TransformedEnv with a
# StepCounter limited to max_step steps; truncation is written under 'foo'.
def build_cpu_single_env():
    env = GymEnv(env_id, device='cpu')
    env = TransformedEnv(env)
    env.append_transform(StepCounter(max_steps=max_step, truncated_key='foo'))
    return env


# Return a ProbabilisticActor for env: a LazyLinear layer sized to
# env.action_spec.space.n maps 'observation' to 'logits', which feed a
# OneHotCategorical with ExplorationType.RANDOM as default interaction type.
def build_actor(env):
    return ProbabilisticActor(
        module=TensorDictModule(
            nn.LazyLinear(env.action_spec.space.n),
            in_keys=['observation'],
            out_keys=['logits'],
        ),
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=['logits'],
        default_interaction_type=ExplorationType.RANDOM,
    )


if __name__ == '__main__':
    # Batched env is placed on device while every sub-env is built on CPU.
    env = ParallelEnv(n_env, EnvCreator(lambda: build_cpu_single_env()), device=device)
    # Empty TransformedEnv around the batched env (no transform is passed).
    env = TransformedEnv(env)
    policy_module = build_actor(env)
    policy_module.to(device)
    # One call on a reset tensordict before rolling out
    # (presumably to materialize the LazyLinear weights -- TODO confirm).
    policy_module(env.reset())
    for i in tqdm.tqdm(range(10)):
        # Roll out 3 steps past the StepCounter limit without stopping on done.
        batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
        max_step_count = batches['next', 'step_count'].max().item()
        # Failure mode 1: a recorded count exceeds the limit.
        if max_step_count > max_step:
            print('Problem 1!')
            print(max_step_count)
            print(batches['next', 'step_count'])
            break
        # Failure mode 2: the highest recorded count stays below the limit.
        elif max_step_count < max_step:
            print('Problem 2!')
            print(max_step_count)
            print(batches['next', 'step_count'])
            break
    else:
        # for/else: reached only when no rollout hit a break above.
        print('No problem!')
"""
/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
  logger.warn(
/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
  logger.warn(
2024-02-02 13:44:03,727 [torchrl][INFO] resetting implement_for
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
  warnings.warn('Lazy modules are a new feature under heavy development '
2024-02-02 13:44:03,981 [torchrl][INFO] initiating worker 0
2024-02-02 13:44:04,041 [torchrl][INFO] initiating worker 1
2024-02-02 13:44:04,043 [torchrl][INFO] initiating worker 2
2024-02-02 13:44:04,045 [torchrl][INFO] initiating worker 3
2024-02-02 13:44:06,233 [torchrl][INFO] resetting implement_for
2024-02-02 13:44:06,241 [torchrl][INFO] resetting implement_for
2024-02-02 13:44:06,255 [torchrl][INFO] resetting implement_for
2024-02-02 13:44:06,273 [torchrl][INFO] resetting implement_for
 10%|████████████████████▉                                                                                                                                                                                            | 1/10 [00:00<00:07,  1.23it/s]Problem 2!
1
tensor([[[1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1]],

        [[1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1]],

        [[1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1]],

        [[1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1]]], device='cuda:0')
 10%|████████████████████▉      
❯ python
Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torchrl, tensordict, torch, numpy, sys
>>> print(torch.__version__, tensordict.__version__, torchrl.__version__, numpy.__version__, sys.version, sys.platform)
2.2.0a0+81ea7a4 0.4.0+eaef29e 0.4.0+1ea3c74 1.24.4 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] linux

I will try on a different cluster.

skandermoalla avatar Feb 02 '24 13:02 skandermoalla

Happens on two different clusters with different CPUs and GPUs (same Docker image though, the NVIDIA NGC PyTorch).

skandermoalla avatar Feb 02 '24 13:02 skandermoalla

On MPS it's no longer a segfault, but the original arbitrary-number bug is back:

from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    StepCounter,
    TransformedEnv,
    SerialEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

max_step = 10
n_env = 4
env_id = "CartPole-v1"
device = "mps"


def build_cpu_single_env():
    """Build one CPU-resident Gym env wrapped with a step counter.

    The ``StepCounter`` is limited to ``max_step`` steps and uses its default
    output keys.
    """
    env = GymEnv(env_id, device="cpu")
    env = TransformedEnv(env)
    env.append_transform(StepCounter(max_steps=max_step))
    return env


def build_actor(env):
    """Return a ``ProbabilisticActor`` for ``env``.

    A ``LazyLinear`` layer sized to ``env.action_spec.space.n`` maps
    ``"observation"`` to ``"logits"``; the actor then uses those logits with
    ``OneHotCategorical`` and ``ExplorationType.RANDOM`` as default
    interaction type.
    """
    return ProbabilisticActor(
        module=TensorDictModule(
            nn.LazyLinear(env.action_spec.space.n),
            in_keys=["observation"],
            out_keys=["logits"],
        ),
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=["logits"],
        default_interaction_type=ExplorationType.RANDOM,
    )


if __name__ == "__main__":
    # Batched env is placed on `device` while every sub-env is built on CPU.
    env = SerialEnv(n_env, EnvCreator(lambda: build_cpu_single_env()), device=device)
    policy_module = build_actor(env)
    policy_module.to(device)
    # One call on a reset tensordict before rolling out
    # (presumably to materialize the LazyLinear weights -- TODO confirm).
    policy_module(env.reset())

    for i in range(10):
        # Roll out 3 steps past the StepCounter limit without stopping on done.
        batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
        max_step_count = batches["next", "step_count"].max().item()
        # Always dump the max and the full step-count tensor for inspection.
        print(max_step_count)
        print(batches["next", "step_count"])
        # Any recorded count above the limit is reported as the problem.
        if max_step_count > max_step:
            print("Problem!")
            print(max_step_count)
            break
    else:
        # for/else: reached only when no rollout hit the `break` above.
        print("No problem!")

gives

/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
  logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
  logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/torch/nn/modules/lazy.py:181: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
  warnings.warn('Lazy modules are a new feature under heavy development '
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
  logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
  logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
  logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
  logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
  logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
  logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
  logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
  logger.warn(
5157210208
tensor([[[         6],
         [         6],
         [         1],
         [         6],
         [         0],
         [         6],
         [         6],
         [         6],
         [         6],
         [         6],
         [         6],
         [         1],
         [         6]],

        [[         1],
         [         1],
         [         1],
         [         1],
         [5157210208],
         [         1],
         [         2],
         [         1],
         [         1],
         [         2],
         [         1],
         [         2],
         [         1]],

        [[         1],
         [         1],
         [         1],
         [         1],
         [         6],
         [         1],
         [         2],
         [         1],
         [         1],
         [         2],
         [         1],
         [         2],
         [         1]],

        [[         1],
         [         1],
         [         1],
         [         1],
         [         1],
         [         1],
         [         2],
         [         1],
         [         1],
         [         2],
         [         1],
         [         2],
         [         1]]], device='mps:0')
Problem!
5157210208

skandermoalla avatar Feb 02 '24 14:02 skandermoalla

MPS still gives segfault for ParallelEnv.

skandermoalla avatar Feb 02 '24 14:02 skandermoalla

I think it's solved now (for cuda on serial and parallel on the bugfix PR). I will have a look at mps later!

vmoens avatar Feb 04 '24 21:02 vmoens

Does this need a specific branch on tensordict?

skandermoalla avatar Feb 05 '24 16:02 skandermoalla

Yeah, sorry, I'm patching TensorDict. Let me quickly revert the latest changes, which should be part of another PR

vmoens avatar Feb 05 '24 16:02 vmoens

I changed it, and tests seem to be passing. If they all do, I'll do a final run of your examples and check the status on MPS. If it all runs smoothly I will consider the PR as good unless you wish to do a proper review of it.

vmoens avatar Feb 05 '24 17:02 vmoens

I'll poke a bit now and give my feedback soon. So I should test with this branch on TorchRL and main on Tensordict?

skandermoalla avatar Feb 05 '24 17:02 skandermoalla

Yes tensordict main is up to date

vmoens avatar Feb 05 '24 17:02 vmoens

All good for CUDA! Awesome! (tested some scripts but didn't check the PR code)

❯ python
Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torchrl, tensordict, torch, numpy, sys
>>> print(torch.__version__, tensordict.__version__, torchrl.__version__, numpy.__version__, sys.version, sys.platform)
2.2.0a0+81ea7a4 0.4.0+99705db 0.4.0+1f485e9 1.24.4 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] linux

(Played with variations of things like here https://github.com/skandermoalla/TorchRL/tree/34c8abf19fd5a5177a2d5eadd5a5b1f57d51ab6c/tests)

skandermoalla avatar Feb 05 '24 18:02 skandermoalla

Testing for MPS.

skandermoalla avatar Feb 05 '24 18:02 skandermoalla

Not yet for MPS.

For SerialEnv (https://github.com/skandermoalla/TorchRL/blob/34c8abf19fd5a5177a2d5eadd5a5b1f57d51ab6c/tests/issue_env_device_serial.py) I have different errors that appear arbitrarily:

python(22437,0x1d9879300) malloc: tiny_free_list_remove_ptr: Internal invariant broken (next ptr of prev): ptr=0x139ced580, prev_next=0x0
python(22437,0x1d9879300) malloc: *** set a breakpoint in malloc_error_break to debug
[1]    22437 abort      python tests/issue_env_device_serial.py
Traceback (most recent call last):
  File "/Users/skander/projects/open-source/TorchRL/tests/issue_env_device_serial.py", line 47, in <module>
    batches = env.rollout((2 * max_step + 3), policy=policy_module, break_when_any_done=False)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2395, in rollout
    tensordicts = self._rollout_nonstop(**kwargs)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2484, in _rollout_nonstop
    tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2554, in step_and_maybe_reset
    tensordict_ = self.reset(tensordict_)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2056, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 58, in decorated_fun
    return fun(self, *args, **kwargs)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 781, in _reset
    _td = _env.reset(tensordict=tensordict_, **kwargs)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2071, in reset
    return self._reset_proc_data(tensordict, tensordict_reset)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/transforms/transforms.py", line 795, in _reset_proc_data
    self._reset_check_done(tensordict, tensordict_reset)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2103, in _reset_check_done
    raise RuntimeError(
RuntimeError: Env done entry 'truncated' was (partially) True after reset on specified '_reset' dimensions. This is not allowed.

and

python(22700,0x16dcdb000) malloc: Incorrect checksum for freed object 0x13b396f08: probably modified after being freed.
Corrupt value: 0xbc99eb7cbad79154
python(22700,0x16dcdb000) malloc: *** set a breakpoint in malloc_error_break to debug
[1]    22700 abort      python tests/issue_env_device_serial.py

For ParallelEnv (https://github.com/skandermoalla/TorchRL/blob/34c8abf19fd5a5177a2d5eadd5a5b1f57d51ab6c/tests/issue_env_device_parallel.py) I also have arbitrary errors

Traceback (most recent call last):
  File "/Users/skander/projects/open-source/TorchRL/tests/issue_env_device_parallel.py", line 45, in <module>
    policy_module(env.reset())
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2071, in reset
    return self._reset_proc_data(tensordict, tensordict_reset)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/transforms/transforms.py", line 795, in _reset_proc_data
    self._reset_check_done(tensordict, tensordict_reset)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2123, in _reset_check_done
    raise RuntimeError(
RuntimeError: The done entry 'truncated' was (partially) True after a call to reset() in env TransformedEnv(
    env=ParallelEnv(
        env=TransformedEnv(
        env=GymEnv(env=MountainCar-v0, batch_size=torch.Size([]), device=cpu),
        transform=Compose(
                StepCounter(keys=[]))), 
        batch_size=torch.Size([4])),
    transform=Compose(
    )).

and

Traceback (most recent call last):
  File "/Users/skander/projects/open-source/TorchRL/tests/issue_env_device_parallel.py", line 47, in <module>
    batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2395, in rollout
    tensordicts = self._rollout_nonstop(**kwargs)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2484, in _rollout_nonstop
    tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2554, in step_and_maybe_reset
    tensordict_ = self.reset(tensordict_)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2056, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/transforms/transforms.py", line 785, in _reset
    tensordict_reset = self.base_env._reset(tensordict=tensordict, **kwargs)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 58, in decorated_fun
    return fun(self, *args, **kwargs)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 1298, in _reset
    channel.send(out)
  File "/Users/skander/mambaforge/envs/torchrl/lib/python3.10/multiprocessing/connection.py", line 206, in send
    self._send_bytes(_ForkingPickler.dumps(obj))
  File "/Users/skander/mambaforge/envs/torchrl/lib/python3.10/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
  File "/Users/skander/mambaforge/envs/torchrl/lib/python3.10/site-packages/torch/multiprocessing/reductions.py", line 557, in reduce_storage
    metadata = storage._share_filename_cpu_()
  File "/Users/skander/mambaforge/envs/torchrl/lib/python3.10/site-packages/torch/storage.py", line 294, in wrapper
    return fn(self, *args, **kwargs)
  File "/Users/skander/mambaforge/envs/torchrl/lib/python3.10/site-packages/torch/storage.py", line 368, in _share_filename_cpu_
    return super()._share_filename_cpu_(*args, **kwargs)
RuntimeError: _share_filename_: only available on CPU

skandermoalla avatar Feb 05 '24 19:02 skandermoalla

Those seem to be different issues than the CUDA ones, so I think we should go ahead with the PR and make sure these things are working ok with MPS separately!

vmoens avatar Feb 05 '24 20:02 vmoens

Reopening to keep track of progress with MPS

vmoens avatar Feb 05 '24 20:02 vmoens

If I avoid updating slices (see this bug) (which happens if you create a single copy of the env) I have no issue with the following code

from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    StepCounter,
    TransformedEnv,
    SerialEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

# Step limit handed to StepCounter in each sub-environment.
max_step = 10
# Number of sub-environments wrapped by the batched env.
n_env = 4
# Gym environment id passed to GymEnv.
env_id = "CartPole-v1"
# Device of the batched env and of the policy (the sub-envs themselves are built on CPU).
device = "mps"


def build_cpu_single_env():
    """Build one CPU-resident ``GymEnv`` wrapped with a ``StepCounter`` transform.

    Reads the module-level ``env_id`` and ``max_step``.

    Returns:
        The ``TransformedEnv`` wrapping the gym env, with a
        ``StepCounter(max_steps=max_step)`` appended.
    """
    # The single env is explicitly kept on CPU; only the batched wrapper uses `device`.
    env = GymEnv(env_id, device="cpu")
    env = TransformedEnv(env)
    # Uses StepCounter's default keys; the driver below reads the count under "step_count".
    env.append_transform(StepCounter(max_steps=max_step))
    return env


def build_actor(env):
    """Build a stochastic policy for ``env``.

    A lazily-shaped linear layer maps the "observation" entry to "logits",
    which parameterize a ``OneHotCategorical`` distribution over actions.

    Args:
        env: environment whose ``action_spec`` gives the number of actions
            (``action_spec.space.n``) and is passed on as the actor's spec.

    Returns:
        The configured ``ProbabilisticActor``.
    """
    return ProbabilisticActor(
        module=TensorDictModule(
            # Input size is inferred on first call; output size is the action count.
            nn.LazyLinear(env.action_spec.space.n),
            in_keys=["observation"],
            out_keys=["logits"],
        ),
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=["logits"],
        # NOTE(review): presumably makes the actor sample actions by default — confirm.
        default_interaction_type=ExplorationType.RANDOM,
    )


# Reproduction driver: run several rollouts and check that the reported step
# count never exceeds the StepCounter limit.
if __name__ == "__main__":
    # One distinct EnvCreator per sub-env (rather than a single shared creator).
    env = SerialEnv(n_env, [EnvCreator(build_cpu_single_env) for _ in range(n_env)], device=device)
    policy_module = build_actor(env)
    policy_module.to(device)
    # One forward pass on a reset tensordict.
    # NOTE(review): presumably needed to materialize the LazyLinear weights — confirm.
    policy_module(env.reset())

    for i in range(10):
        # Roll out past max_step so every sub-env has to cross the StepCounter limit;
        # break_when_any_done=False keeps the rollout going after an env is done.
        batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
        max_step_count = batches["next", "step_count"].max().item()
        print(max_step_count)
        print(batches["next", "step_count"])
        if max_step_count > max_step:
            # A count above the limit is the symptom being reproduced.
            print("Problem!")
            print(max_step_count)
            break
    else:
        # for/else: reached only when no iteration hit `break`.
        print("No problem!")

Notice the way I create the serial env (I don't use the same EnvCreator)

I also have no issue with ParallelEnv.

vmoens avatar Feb 07 '24 09:02 vmoens

Not for me. Running the above script gives me the same transient errors I described. Which commits are you using? I'm on the main branches of both TorchRL and TensorDict. Here's my environment:

torchrl ❯ mamba env export                                                        
name: torchrl
channels:
  - pytorch
  - conda-forge
dependencies:
  - brotli=1.1.0=hb547adb_1
  - brotli-bin=1.1.0=hb547adb_1
  - bzip2=1.0.8=h93a5062_5
  - ca-certificates=2023.11.17=hf0a4a13_0
  - certifi=2023.11.17=pyhd8ed1ab_0
  - colorama=0.4.6=pyhd8ed1ab_0
  - contourpy=1.2.0=py310hd137fd4_0
  - cycler=0.12.1=pyhd8ed1ab_0
  - exceptiongroup=1.2.0=pyhd8ed1ab_2
  - filelock=3.13.1=pyhd8ed1ab_0
  - fonttools=4.47.2=py310hd125d64_0
  - freetype=2.12.1=hadb7bae_2
  - gmp=6.3.0=h965bd2d_0
  - gmpy2=2.1.2=py310h2e6cad2_1
  - imageio=2.33.1=pyh8c1a49c_0
  - iniconfig=2.0.0=pyhd8ed1ab_0
  - jinja2=3.1.3=pyhd8ed1ab_0
  - kiwisolver=1.4.5=py310h38f39d4_1
  - lcms2=2.16=ha0e7c42_0
  - lerc=4.0.0=h9a09cb3_0
  - libblas=3.9.0=21_osxarm64_openblas
  - libbrotlicommon=1.1.0=hb547adb_1
  - libbrotlidec=1.1.0=hb547adb_1
  - libbrotlienc=1.1.0=hb547adb_1
  - libcblas=3.9.0=21_osxarm64_openblas
  - libcxx=16.0.6=h4653b0c_0
  - libdeflate=1.19=hb547adb_0
  - libffi=3.4.2=h3422bc3_5
  - libgfortran=5.0.0=13_2_0_hd922786_2
  - libgfortran5=13.2.0=hf226fd6_2
  - libjpeg-turbo=3.0.0=hb547adb_1
  - liblapack=3.9.0=21_osxarm64_openblas
  - libopenblas=0.3.26=openmp_h6c19121_0
  - libpng=1.6.39=h76d750c_0
  - libsqlite=3.44.2=h091b4b1_0
  - libtiff=4.6.0=ha8a6c65_2
  - libwebp-base=1.3.2=hb547adb_0
  - libxcb=1.15=hf346824_0
  - libzlib=1.2.13=h53f4e23_5
  - llvm-openmp=17.0.6=hcd81f8e_0
  - markupsafe=2.1.4=py310hd125d64_0
  - matplotlib=3.8.2=py310hb6292c7_0
  - matplotlib-base=3.8.2=py310h9d2df84_0
  - mpc=1.3.1=h91ba8db_0
  - mpfr=4.2.1=h9546428_0
  - mpmath=1.3.0=pyhd8ed1ab_0
  - munkres=1.1.4=pyh9f0ad1d_0
  - ncurses=6.4=h463b476_2
  - networkx=3.2.1=pyhd8ed1ab_0
  - numpy=1.26.3=py310hd45542a_0
  - openjpeg=2.5.0=h4c1507b_3
  - openssl=3.2.1=h0d3ecfb_0
  - packaging=23.2=pyhd8ed1ab_0
  - pcre2=10.42=h26f9a81_0
  - pillow=10.2.0=py310hfae7ebd_0
  - pip=23.3.2=pyhd8ed1ab_0
  - pluggy=1.4.0=pyhd8ed1ab_0
  - pthread-stubs=0.4=h27ca646_1001
  - pyparsing=3.1.1=pyhd8ed1ab_0
  - pytest=8.0.0=pyhd8ed1ab_0
  - python=3.10.13=h2469fbe_1_cpython
  - python-dateutil=2.8.2=pyhd8ed1ab_0
  - python_abi=3.10=4_cp310
  - pytorch=2.2.0=py3.10_0
  - pyyaml=6.0.1=py310h2aa6e3c_1
  - readline=8.2=h92ec313_1
  - setuptools=69.0.3=pyhd8ed1ab_0
  - six=1.16.0=pyh6c4a22f_0
  - sympy=1.12=pypyh9d50eac_103
  - tk=8.6.13=h5083fa2_1
  - tomli=2.0.1=pyhd8ed1ab_0
  - tornado=6.3.3=py310h2aa6e3c_1
  - typing_extensions=4.9.0=pyha770c72_0
  - tzdata=2023d=h0c530f3_0
  - unicodedata2=15.1.0=py310h2aa6e3c_0
  - wheel=0.42.0=pyhd8ed1ab_0
  - xorg-libxau=1.0.11=hb547adb_0
  - xorg-libxdmcp=1.1.3=h27ca646_0
  - xz=5.2.6=h57fd34a_0
  - yaml=0.2.5=h3422bc3_2
  - zstd=1.5.5=h4f39d0f_0
  - pip:
      - absl-py==2.1.0
      - ale-py==0.8.1
      - annotated-types==0.6.0
      - antlr4-python3-runtime==4.9.3
      - appdirs==1.4.4
      - attrs==23.2.0
      - autorom==0.4.2
      - autorom-accept-rom-license==0.6.1
      - black==24.1.1
      - box2d-py==2.3.5
      - cfgv==3.4.0
      - charset-normalizer==3.3.2
      - click==8.1.7
      - cloudpickle==3.0.0
      - decorator==4.4.2
      - distlib==0.3.8
      - docker-pycreds==0.4.0
      - etils==1.6.0
      - farama-notifications==0.0.4
      - fsspec==2023.12.2
      - gitdb==4.0.11
      - gitpython==3.1.41
      - glfw==2.6.5
      - gymnasium==0.29.1
      - hydra-core==1.3.2
      - identify==2.5.33
      - idna==3.6
      - imageio-ffmpeg==0.4.9
      - importlib-resources==6.1.1
      - joblib==1.3.2
      - jsonref==1.1.0
      - jsonschema==4.21.1
      - jsonschema-specifications==2023.12.1
      - moviepy==1.0.3
      - mujoco==3.1.1
      - mypy-extensions==1.0.0
      - nodeenv==1.8.0
      - omegaconf==2.3.0
      - pathspec==0.12.1
      - platformdirs==4.2.0
      - pre-commit==3.6.0
      - proglog==0.1.10
      - protobuf==4.25.2
      - psutil==5.9.8
      - pydantic==2.6.0
      - pydantic-core==2.16.1
      - pygame==2.5.2
      - pyopengl==3.1.7
      - referencing==0.33.0
      - requests==2.31.0
      - rpds-py==0.17.1
      - scikit-learn==1.4.0
      - scipy==1.12.0
      - sentry-sdk==1.40.0
      - setproctitle==1.3.3
      - shimmy==0.2.1
      - smmap==5.0.1
      - sweeps==0.2.0
      - swig==4.1.1.post1
      - threadpoolctl==3.2.0
      - tqdm==4.66.1
      - urllib3==2.2.0
      - virtualenv==20.25.0
      - wandb==0.16.2
      - zipp==3.17.0

skandermoalla avatar Feb 11 '24 17:02 skandermoalla

Can you check #1900 whenever you have time?

vmoens avatar Feb 12 '24 08:02 vmoens

Almost solved. It works with SerialEnv and ParallelEnv, but somehow breaks when a TransformedEnv is added on top of the ParallelEnv.

from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
    ExplorationType,
    StepCounter,
    TransformedEnv,
    ParallelEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

# Step limit handed to StepCounter in each sub-environment.
max_step = 10
# Number of sub-environments wrapped by the batched env.
n_env = 4
# Gym environment id passed to GymEnv.
env_id = "MountainCar-v0"
# Device of the batched env and of the policy (the sub-envs themselves are built on CPU).
device = "mps"


def build_cpu_single_env():
    """Build one CPU-resident ``GymEnv`` wrapped with a ``StepCounter`` transform.

    Reads the module-level ``env_id`` and ``max_step``.

    Returns:
        The ``TransformedEnv`` wrapping the gym env, with a ``StepCounter``
        appended whose truncation flag is written under the key "foo".
    """
    # The single env is explicitly kept on CPU; only the batched wrapper uses `device`.
    env = GymEnv(env_id, device="cpu")
    env = TransformedEnv(env)
    # NOTE(review): truncated_key="foo" presumably keeps the counter's truncation
    # flag separate from the env's own "truncated" entry — confirm.
    env.append_transform(StepCounter(max_steps=max_step, truncated_key="foo"))
    return env


def build_actor(env):
    """Build a stochastic policy for ``env``.

    A lazily-shaped linear layer maps the "observation" entry to "logits",
    which parameterize a ``OneHotCategorical`` distribution over actions.

    Args:
        env: environment whose ``action_spec`` gives the number of actions
            (``action_spec.space.n``) and is passed on as the actor's spec.

    Returns:
        The configured ``ProbabilisticActor``.
    """
    return ProbabilisticActor(
        module=TensorDictModule(
            # Input size is inferred on first call; output size is the action count.
            nn.LazyLinear(env.action_spec.space.n),
            in_keys=["observation"],
            out_keys=["logits"],
        ),
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=["logits"],
        # NOTE(review): presumably makes the actor sample actions by default — confirm.
        default_interaction_type=ExplorationType.RANDOM,
    )


# Reproduction driver: a ParallelEnv of CPU sub-envs on `device`, wrapped in an
# outer TransformedEnv, rolled out repeatedly while checking the step count.
if __name__ == "__main__":
    # Works with both ParallelEnv and SerialEnv
    env = ParallelEnv(n_env, lambda: build_cpu_single_env(), device=device)
    # Breaks when adding a Transformed env on Parallel Env.
    env = TransformedEnv(env)
    policy_module = build_actor(env)
    policy_module.to(device)
    # One forward pass on a reset tensordict.
    # NOTE(review): presumably needed to materialize the LazyLinear weights — confirm.
    policy_module(env.reset())

    # Cap the expected maximum at 200.
    # NOTE(review): 200 is presumably MountainCar-v0's built-in episode limit — confirm.
    max_step = min(max_step, 200)
    for i in range(10):
        # Roll out past max_step so every sub-env has to cross the StepCounter limit;
        # break_when_any_done=False keeps the rollout going after an env is done.
        batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
        max_step_count = batches["next", "step_count"].max().item()
        # print(max_step_count)
        # print(batches["next", "step_count"])
        if max_step_count > max_step:
            # Count overshot the limit.
            print("Problem 1!")
            print(max_step_count)
            print(batches["next", "step_count"])
            break
        elif max_step_count < max_step:
            # Count never reached the limit, although the rollout is longer than max_step.
            print("Problem 2!")
            print(max_step_count)
            print(batches["next", "step_count"])
            break
    else:
        # for/else: reached only when no iteration hit `break`.
        print("No problem!")

Traceback (most recent call last):
  File "/Users/moalla/projects/open-source/TorchRL/tests/issue_env_device_parallel.py", line 48, in <module>
    batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2395, in rollout
    tensordicts = self._rollout_nonstop(**kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2484, in _rollout_nonstop
    tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2554, in step_and_maybe_reset
    tensordict_ = self.reset(tensordict_)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2056, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/transforms/transforms.py", line 785, in _reset
    tensordict_reset = self.base_env._reset(tensordict=tensordict, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 58, in decorated_fun
    return fun(self, *args, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 1315, in _reset
    self.shared_tensordicts[i].apply_(
  File "/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/base.py", line 3597, in apply_
    return self.apply(fn, *others, inplace=True, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/base.py", line 3692, in apply
    return self._apply_nest(
  File "/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/_td.py", line 712, in _apply_nest
    item_trsf = fn(item, *_others)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 1312, in tentative_update
    val.copy_(other)
RuntimeError: destOffset % 4 == 0 INTERNAL ASSERT FAILED at "/Users/runner/work/_temp/anaconda/conda-bld/pytorch_1704987091277/work/aten/src/ATen/native/mps/operations/Copy.mm":107, please report a bug to PyTorch. Unaligned blit request

skandermoalla avatar Feb 12 '24 10:02 skandermoalla

Actually there is another issue with ParallelEnv. The native truncation key is not faithful.

from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import ExplorationType, ParallelEnv, StepCounter, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

# Requested step limit; clamped to NATIVE_TRUNCATION at the end of this block.
max_step = 210
# Number of sub-environments wrapped by the batched env.
n_env = 4
# Gym environment id passed to GymEnv.
env_id = "MountainCar-v0"
# NOTE(review): presumably the episode length at which MountainCar-v0 truncates on its own — confirm.
NATIVE_TRUNCATION = 200
# Device of the batched env and of the policy (the sub-envs themselves are built on CPU).
device = "mps"
# Effective limit: the smaller of the requested limit and the native one (200 here).
max_step = min(max_step, NATIVE_TRUNCATION)


def build_cpu_single_env():
    """Build one CPU-resident ``GymEnv`` wrapped with a ``StepCounter`` transform.

    Reads the module-level ``env_id`` and ``max_step``.

    Returns:
        The ``TransformedEnv`` wrapping the gym env, with a
        ``StepCounter(max_steps=max_step)`` appended.
    """
    # The single env is explicitly kept on CPU; only the batched wrapper uses `device`.
    env = GymEnv(env_id, device="cpu")
    env = TransformedEnv(env)
    # Uses StepCounter's default keys; the driver below reads the count under "step_count".
    env.append_transform(StepCounter(max_steps=max_step))
    return env


def build_actor(env):
    """Build a stochastic policy for ``env``.

    A lazily-shaped linear layer maps the "observation" entry to "logits",
    which parameterize a ``OneHotCategorical`` distribution over actions.

    Args:
        env: environment whose ``action_spec`` gives the number of actions
            (``action_spec.space.n``) and is passed on as the actor's spec.

    Returns:
        The configured ``ProbabilisticActor``.
    """
    return ProbabilisticActor(
        module=TensorDictModule(
            # Input size is inferred on first call; output size is the action count.
            nn.LazyLinear(env.action_spec.space.n),
            in_keys=["observation"],
            out_keys=["logits"],
        ),
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=["logits"],
        # NOTE(review): presumably makes the actor sample actions by default — confirm.
        default_interaction_type=ExplorationType.RANDOM,
    )


# Reproduction driver: checks that the maximum step count seen in each rollout
# equals max_step exactly (neither above nor below).
if __name__ == "__main__":
    env = ParallelEnv(n_env, lambda: build_cpu_single_env(), device=device)
    # env = TransformedEnv(env)
    policy_module = build_actor(env)
    policy_module.to(device)
    # One forward pass on a reset tensordict.
    # NOTE(review): presumably needed to materialize the LazyLinear weights — confirm.
    policy_module(env.reset())

    for i in range(10):
        # Roll out past max_step so every sub-env has to cross the limit;
        # break_when_any_done=False keeps the rollout going after an env is done.
        batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
        max_step_count = batches["next", "step_count"].max().item()
        if max_step_count > max_step:
            print(max_step_count)
            # Show only the last five time steps of each sub-env.
            print(batches["next", "step_count"][:, -5:])
            print("Problem! Got higher than max step count.")
            break
        elif max_step_count < max_step:
            # The counter restarted before reaching the limit.
            print(max_step_count)
            print(batches["next", "step_count"][:, -5:])
            print("Problem: Got less than max step count!")
            break
    else:
        # for/else: reached only when no iteration hit `break`.
        print(batches["next", "step_count"][:, -5:])
        print("No problem!")
196
tensor([[[194],
         [195],
         [  1],
         [  2],
         [  3]],

        [[194],
         [195],
         [  1],
         [  2],
         [  3]],

        [[195],
         [196],
         [  1],
         [  2],
         [  3]],

        [[195],
         [196],
         [  1],
         [  2],
         [  3]]], device='mps:0')
Problem: Got less than max step count!

skandermoalla avatar Feb 12 '24 10:02 skandermoalla

OK, this will require me to have access to an MPS device then (I won't have one for the next 3 weeks, I think) :/

vmoens avatar Feb 12 '24 12:02 vmoens

This is now solved, right?

skandermoalla avatar Mar 27 '24 16:03 skandermoalla