[Bug] With `MultiSyncDataCollector`, `tensors` cannot be instantiated on CUDA in child processes.
Describe the bug
Despite applying the appropriate guards (`mp.set_start_method('spawn')`, `if __name__ == "__main__"`), using `MultiSyncDataCollector` with a CUDA device causes the program to freeze.
To Reproduce
# BEFORE THE PROGRAM EVEN RUNS, FORCE THE START METHOD TO BE 'SPAWN'
from torch import multiprocessing as mp
mp.set_start_method("spawn", force = True)
from copy import deepcopy
import tqdm
import numpy as np
import math
import torch
from torch import nn
import torch.distributions as D
from torchrl.envs import check_env_specs, PettingZooEnv, ParallelEnv
from torchrl.modules import ProbabilisticActor
from torchrl.modules.models import MLP
from torchrl.modules.models.multiagent import MultiAgentNetBase
from torchrl.collectors import MultiSyncDataCollector
from tensordict.nn import TensorDictModule, TensorDictSequential, NormalParamExtractor
from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter
from torchrl.record import CSVLogger, VideoRecorder, PixelRenderTransform
EPS = 1e-7
class SMACCNet(MultiAgentNetBase):
def __init__(self,
n_agent_inputs: int | None,
n_agent_outputs: int,
n_agents: int,
centralised: bool,
share_params: bool,
device = 'cpu',
activation_class = nn.Tanh,
**kwargs):
self.n_agents = n_agents
self.n_agent_inputs = n_agent_inputs
self.n_agent_outputs = n_agent_outputs
self.share_params = share_params
self.centralised = centralised
self.activation_class = activation_class
self.device = device
super().__init__(
n_agents=n_agents,
centralised=centralised,
share_params=share_params,
agent_dim=-2,
device = device,
**kwargs,
)
def _pre_forward_check(self, inputs):
if inputs.shape[-2] != self.n_agents:
raise ValueError(
f"Multi-agent network expected input with shape[-2]={self.n_agents},"
f" but got {inputs.shape}"
)
if self.centralised:
inputs = inputs.flatten(-2, -1)
return inputs
def init_net_params(self, net):
def init_layer_params(layer):
if isinstance(layer, nn.Linear):
weight_gain = 1. / (100 if layer.out_features == self.n_agent_outputs else 1)
torch.nn.init.xavier_uniform_(layer.weight, gain = weight_gain)
if 'bias' in layer.state_dict():
torch.nn.init.zeros_(layer.bias)
net.apply(init_layer_params)
return net
def _build_single_net(self, *, device, **kwargs):
n_agent_inputs = self.n_agent_inputs
if self.centralised and n_agent_inputs is not None:
n_agent_inputs = self.n_agent_inputs * self.n_agents
model = nn.Sequential(
nn.Linear(n_agent_inputs, 400),
self.activation_class(),
nn.Linear(400, 300),
self.activation_class(),
nn.Linear(300, self.n_agent_outputs)
).to(self.device) # We are not able to use MultiSyncDataCollector with the 'meta' device JUST YET!!!
model = self.init_net_params(model)
return model
class CustomTanhTransform(D.transforms.TanhTransform):
def _inverse(self, y):
# Yoinked from SB3!!!
"""
Inverse of Tanh
Taken from Pyro: https://github.com/pyro-ppl/pyro
0.5 * torch.log((1 + x ) / (1 - x))
"""
y = y.clamp(-1. + EPS, 1. - EPS)
return 0.5 * (y.log1p() - (-y).log1p())
def log_abs_det_jacobian(self, x, y):
# Yoinked from PyTorch TanhTransform!
'''
tl;dr log(1-tanh^2(x)) = log(sech^2(x))
= 2log(2/(e^x + e^(-x)))
= 2(log2 - log(e^x/(1 + e^(-2x)))
= 2(log2 - x - log(1 + e^(-2x)))
= 2(log2 - x - softplus(-2x))
'''
return 2.0 * (math.log(2.0) - x - nn.functional.softplus(-2.0 * x))
class TanhNormalStable(D.TransformedDistribution):
'''Numerically stable variant of TanhNormal. Employs clipping trick.'''
def __init__(self, loc, scale, event_dims = 1):
self._event_dims = event_dims
self._t = [
CustomTanhTransform()
]
self.update(loc, scale)
def log_prob(self, value):
"""
Scores the sample by inverting the transform(s) and computing the score
using the score of the base distribution and the log abs det jacobian.
"""
if self._validate_args:
self._validate_sample(value)
event_dim = len(self.event_shape)
log_prob = 0.0
y = value
for transform in reversed(self.transforms):
x = transform.inv(y)
event_dim += transform.domain.event_dim - transform.codomain.event_dim
log_prob = log_prob - D.utils._sum_rightmost(
transform.log_abs_det_jacobian(x, y),
event_dim - transform.domain.event_dim,
)
y = x
log_prob = log_prob + D.utils._sum_rightmost(
self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
)
log_prob = torch.clamp(log_prob, min = math.log10(EPS))
return log_prob
def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
self.loc = loc
self.scale = scale
if (
hasattr(self, "base_dist")
and (self.base_dist.base_dist.loc.shape == self.loc.shape)
and (self.base_dist.base_dist.scale.shape == self.scale.shape)
):
self.base_dist.base_dist.loc = self.loc
self.base_dist.base_dist.scale = self.scale
else:
base = D.Independent(D.Normal(self.loc, self.scale), self._event_dims)
super().__init__(base, self._t)
@property
def mode(self):
m = self.base_dist.base_dist.mean
for t in self.transforms:
m = t(m)
return m
# Main Function
if __name__ == "__main__":
NUM_AGENTS = 3
NUM_CRITICS = 2
NUM_EXPLORE_WORKERS = 1
EXPLORATION_STEPS = 30000
MAX_EPISODE_STEPS = 1000
DEVICE = "cuda:0"
REPLAY_BUFFER_SIZE = int(1e6)
VALUE_GAMMA = 0.99
MAX_GRAD_NORM = 1.0
BATCH_SIZE = 512
LR = 3e-4
UPDATE_STEPS_PER_EXPLORATION = 1500
WARMUP_STEPS = int(3e5)
TRAIN_TIMESTEPS = int(1e7)
EVAL_INTERVAL = 10
EVAL_EPISODES = 20
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
def env_fn(mode, parallel = True, rew_scale = True, killswitch = False):
if rew_scale:
terminate_scale = -3.0
forward_scale = 2.5
fall_scale = -3.0
else:
# Use the defaults from PZ
terminate_scale, forward_scale, fall_scale = -100.0, 1.0, -10.0
def base_env_fn():
return PettingZooEnv(task = "multiwalker_v9",
parallel = True,
seed = 42,
n_walkers = NUM_AGENTS,
terminate_reward = terminate_scale,
forward_reward = forward_scale,
fall_reward = fall_scale,
shared_reward = False,
max_cycles = MAX_EPISODE_STEPS,
render_mode = mode,
device = DEVICE
)
env = base_env_fn # noqa: E731
def env_with_transforms():
init_env = env()
init_env = TransformedEnv(init_env, Compose(
StepCounter(max_steps = MAX_EPISODE_STEPS),
RewardSum(
in_keys = [init_env.reward_key for _ in range(NUM_AGENTS)],
out_keys = [("walker", "episode_reward")] * NUM_AGENTS,
reset_keys = ["_reset"] * NUM_AGENTS
),
)
)
if killswitch:
breakpoint()
return init_env
return env_with_transforms
train_env = env_fn(None, parallel = False)()
if train_env.is_closed:
train_env.start()
def create_eval_env(tag = "rendered"):
eval_env = env_fn("rgb_array", parallel = False, rew_scale = False)()
video_recorder = VideoRecorder(
CSVLogger("multiwalker-toy-test", video_format = "mp4"),
tag = tag,
in_keys = ["pixels_record"]
)
# Call the parent's render function
eval_env.append_transform(PixelRenderTransform(out_keys = ["pixels_record"]))
eval_env.append_transform(video_recorder)
if eval_env.is_closed:
eval_env.start()
return eval_env
check_env_specs(train_env)
obs_dim = train_env.full_observation_spec["walker", "observation"].shape[-1]
action_dim = train_env.full_action_spec["walker", "action"].shape[-1]
policy_net = nn.Sequential(
SMACCNet(n_agent_inputs = obs_dim,
n_agent_outputs = 2 * action_dim,
n_agents = NUM_AGENTS,
centralised = False,
share_params = True,
device = DEVICE,
activation_class = nn.LeakyReLU,
),
NormalParamExtractor(),
)
critic_net = SMACCNet(n_agent_inputs = obs_dim + action_dim,
n_agent_outputs = 1,
n_agents = NUM_AGENTS,
centralised = True,
share_params = True,
device = DEVICE,
activation_class = nn.LeakyReLU,
)
policy_net_td_module = TensorDictModule(module = policy_net,
in_keys = [("walker", "observation")],
out_keys = [("walker", "loc"), ("walker", "scale")]
)
obs_act_module = TensorDictModule(lambda obs, act: torch.cat([obs, act], dim = -1),
in_keys = [("walker", "observation"), ("walker", "action")],
out_keys = [("walker", "obs_act")]
)
critic_net_td_module = TensorDictModule(module = critic_net,
in_keys = [("walker", "obs_act")],
out_keys = [("walker", "state_action_value")]
)
# Attach our raw policy network to a probabilistic actor
policy_actor = ProbabilisticActor(
module = policy_net_td_module,
spec = train_env.full_action_spec["walker", "action"],
in_keys = [("walker", "loc"), ("walker", "scale")],
out_keys = [("walker", "action")],
distribution_class = TanhNormalStable,
return_log_prob = True,
)
with torch.no_grad():
fake_td = train_env.fake_tensordict()
policy_actor(fake_td)
critic_actor = TensorDictSequential(
obs_act_module, critic_net_td_module
)
with torch.no_grad():
reset_obs = train_env.reset()
reset_obs_clean = deepcopy(reset_obs)
action = policy_actor(reset_obs)
state_action_value = critic_actor(action)
reset_obs = train_env.reset()
reset_obs["walker", "action"] = torch.zeros((*reset_obs["walker", "observation"].shape[:-1], action_dim))
train_env.rand_action(reset_obs)
action = train_env.step(reset_obs)
print("As you can see, spawning a single environment on the main process is absolutely unproblematic.")
collector = MultiSyncDataCollector(
[env_fn(None, parallel = False, killswitch = True) for _ in range(NUM_EXPLORE_WORKERS)],
policy = policy_actor, # the explora
frames_per_batch = BATCH_SIZE,
max_frames_per_traj = 0,
total_frames = TRAIN_TIMESTEPS,
device = DEVICE,
reset_at_each_iter = False
)
for i, tensordict in (pbar := tqdm.tqdm(enumerate(collector), total = TRAIN_TIMESTEPS)):
pbar.write("Hey Hey!!! :D")
collector.shutdown()
train_env.close()
Execution output:
<program executes as usual>
As you can see, spawning a single environment on the main process is absolutely unproblematic.
<Program freezes indefinitely>
Terminating the program gives this traceback:
Traceback (most recent call last):
File "/mnt/c/Users/N00bcak/Desktop/programming/drones_go_brr/scripts/torchrl_cuda_hangs.py", line 326, in <module>
collector = MultiSyncDataCollector(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/collectors/collectors.py", line 1516, in __init__
self._run_processes()
File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/collectors/collectors.py", line 1690, in _run_processes
msg = pipe_parent.recv()
^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/multiprocessing/connection.py", line 250, in recv
buf = self._recv_bytes()
^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/multiprocessing/connection.py", line 430, in _recv_bytes
buf = self._recv(4)
^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/multiprocessing/connection.py", line 395, in _recv
chunk = read(handle, remaining)
^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt
^CException ignored in atexit callback: <function _exit_function at 0x7f4151e12e80>
Traceback (most recent call last):
File "/usr/local/lib/python3.11/multiprocessing/util.py", line 360, in _exit_function
p.join()
File "/usr/local/lib/python3.11/multiprocessing/process.py", line 149, in join
res = self._popen.wait(timeout)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/multiprocessing/popen_fork.py", line 43, in wait
return self.poll(os.WNOHANG if timeout == 0.0 else 0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/multiprocessing/popen_fork.py", line 27, in poll
pid, sts = os.waitpid(self.pid, flag)
^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt:
Expected behavior
After printing "As you can see, spawning a single environment on the main process is absolutely unproblematic.", the program should proceed into the collector loop and print "Hey Hey!!! :D" repeatedly.
System info
Describe the characteristics of your environment:
- Describe how the library was installed (pip, source, ...)
- Python version
- Versions of any other relevant libraries
>>> import torchrl, numpy, sys
>>> print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.4.0 1.26.4 3.11.9 (main, Jun 5 2024, 10:27:27) [GCC 12.2.0] linux
Additional context
The problem was encountered as part of an effort to spawn multiple environments on the GPU. Any pointers in this direction would be greatly appreciated.
Proof of issue with tensors
By adding a killswitch to env_fn at various positions, we can make the following observations:
Code (No tensor defined yet)
def env_fn(mode, parallel = True, rew_scale = True, killswitch = False):
if rew_scale:
terminate_scale = -3.0
forward_scale = 2.5
fall_scale = -3.0
else:
# Use the defaults from PZ
terminate_scale, forward_scale, fall_scale = -100.0, 1.0, -10.0
def base_env_fn():
return PettingZooEnv(task = "multiwalker_v9",
parallel = True,
seed = 42,
n_walkers = NUM_AGENTS,
terminate_reward = terminate_scale,
forward_reward = forward_scale,
fall_reward = fall_scale,
shared_reward = False,
max_cycles = MAX_EPISODE_STEPS,
render_mode = mode,
device = DEVICE
)
env = base_env_fn # noqa: E731
def env_with_transforms():
if killswitch:
breakpoint() # Killswitch before env initialization
init_env = env()
init_env = TransformedEnv(init_env, Compose(
StepCounter(max_steps = MAX_EPISODE_STEPS),
RewardSum(
in_keys = [init_env.reward_key for _ in range(NUM_AGENTS)],
out_keys = [("walker", "episode_reward")] * NUM_AGENTS,
reset_keys = ["_reset"] * NUM_AGENTS
),
)
)
return init_env
return env_with_transforms
Result: the program crashes, as expected, when the breakpoint is hit inside the child process.
Process _ProcessNoWarn-1:
Traceback (most recent call last):
File "/usr/local/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/_utils.py", line 668, in run
return mp.Process.run(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/collectors/collectors.py", line 2653, in _main_async_collector
inner_collector = SyncDataCollector(
^^^^^^^^^^^^^^^^^^
File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/collectors/collectors.py", line 450, in __init__
self.closed = True
^^^^
File "/usr/local/lib/python3.11/bdb.py", line 90, in trace_dispatch
return self.dispatch_line(frame)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/bdb.py", line 115, in dispatch_line
if self.quitting: raise BdbQuit
^^^^^^^^^^^^^
bdb.BdbQuit
Traceback (most recent call last):
File "/mnt/c/Users/N00bcak/Desktop/programming/drones_go_brr/scripts/torchrl_cuda_hangs.py", line 326, in <module>
collector = MultiSyncDataCollector(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/collectors/collectors.py", line 1518, in __init__
self._run_processes()
File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/collectors/collectors.py", line 1692, in _run_processes
msg = pipe_parent.recv()
^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/multiprocessing/connection.py", line 250, in recv
buf = self._recv_bytes()
^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/multiprocessing/connection.py", line 430, in _recv_bytes
buf = self._recv(4)
^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/multiprocessing/connection.py", line 399, in _recv
raise EOFError
EOFError
[W CudaIPCTypes.cpp:16] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]
Code: insert a CUDA tensor declaration into the killswitch clause
if killswitch:
torchy_mctorchface = torch.tensor([1,2,3,4,5], device = 'cuda:0')
breakpoint()
Result: the program hangs indefinitely.
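For what it's worth, a TorchRL-free check along the same lines could help narrow things down (a hypothetical sketch, not something from the report above): if creating a CUDA tensor in a plain spawned child process also hangs on the affected machine, the problem would sit with CUDA + spawn (e.g. a WSL2 quirk) rather than with `MultiSyncDataCollector` itself.

```python
# Hypothetical sketch: create a CUDA tensor in a spawned child process with no
# TorchRL involvement. If this also hangs, the issue is CUDA + spawn, not the collector.
import torch
from torch import multiprocessing as mp

def make_cuda_tensor():
    t = torch.tensor([1, 2, 3, 4, 5], device="cuda:0")
    print("child process created:", t)

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    p = mp.Process(target=make_cuda_tensor)
    p.start()
    p.join()
```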
PS
Since the error relates to tensors, would it be a good idea to rope in the PyTorch devs?
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)
I did not try the additional context block, but running the code above on my machine without these lines works perfectly fine (the "Hey Hey!!! :D" message is displayed as expected):
if killswitch:
breakpoint()
If I don't remove that block, the program fails on my Python 3.10 env (even if the breakpoint is never reached).
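(As an aside, if one wants to leave the killswitch block in the script while making it inert, PEP 553's `PYTHONBREAKPOINT` hook can disable `breakpoint()` globally; a sketch, assuming the spawned workers inherit the parent's environment:)

```python
# Sketch: with PYTHONBREAKPOINT=0, breakpoint() becomes a no-op (PEP 553),
# so the killswitch block can stay in place without tripping pdb/BdbQuit
# inside the child processes. Set it before any breakpoint() is evaluated.
import os
os.environ["PYTHONBREAKPOINT"] = "0"
```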
Some further things we can look at to debug:
What env variables are you setting, if any? What CUDA / PyTorch versions do you have? Does the CUDA version your PyTorch was built against match the CUDA on the machine?
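A quick way to gather that information in one go (a hypothetical diagnostic sketch; the environment-variable list is just a guess at the usual suspects):

```python
# Hypothetical diagnostic sketch: report the Python/PyTorch/CUDA combination and a few
# environment variables that commonly affect CUDA + multiprocessing behaviour.
import os
import sys
import torch

print(sys.version, sys.platform)
print("torch:", torch.__version__, "| built against CUDA:", torch.version.cuda)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))
for var in ("CUDA_VISIBLE_DEVICES", "CUDA_LAUNCH_BLOCKING", "PYTORCH_CUDA_ALLOC_CONF"):
    print(var, "=", os.environ.get(var))
```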
tl;dr: this seems to be either a WSL2-Debian or a Python 3.11 quirk. Very interesting.
Part 1
My bad, I should have specified that I was on WSL2-Debian.
Here's some information regarding that:
Debian Version
> python3 -c "import sys, torch, torchrl, tensordict; print(sys.version, torch.__version__, torchrl.__version__, tensordict.__version__)"
3.11.9 (main, Jun 5 2024, 10:27:27) [GCC 12.2.0] 2.3.0+cu121 0.4.0 0.4.0
> lsb_release -a
No LSB modules are available.
Distributor ID: Debian
Description: Debian GNU/Linux 12 (bookworm)
Release: 12
Codename: bookworm
WSL Version
PS C:\Windows\system32> (get-item C:\windows\system32\wsl.exe).VersionInfo.FileVersion
10.0.19041.3636 (WinBuild.160101.0800)
Part 2
Strange. I am now using Python 3.10 on a different (single-boot Ubuntu) machine, and I cannot reproduce the bug there either.
This is my Python environment:
Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import torchrl
>>> import tensordict
>>> torch.__version__, torchrl.__version__, tensordict.__version__
('2.3.0+cu121', '0.4.0', '0.4.0')
What CUDA / PyTorch versions do you have? Does the CUDA version your PyTorch was built against match the CUDA on the machine?
Both of my machines use the CUDA that comes with PyTorch.
What env variables are you setting, if any?
The offending scripts do not have any special environment variables set.