[BUG] Numerical Instability issues with `torchrl.modules.TanhNormal`
Describe the bug
When training on PettingZoo/MultiWalker-v9 with Multi-Agent Soft Actor-Critic, all losses (loss_actor, loss_qvalue, loss_alpha) explode after at most ~1M environment steps.
This happens regardless of (reasonable) hyperparameter choices and gradient clipping thresholds.
To Reproduce
from copy import deepcopy
import tqdm
import numpy as np
from gymnasium.spaces import Box
import logging
import math
import torch
from torch import nn
import torch.distributions as D
from torchrl.data.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.envs import (
check_env_specs,
PettingZooEnv,
ParallelEnv,
GymEnv
)
from torchrl.modules import AdditiveGaussianWrapper, ProbabilisticActor, TanhNormal
from torchrl.modules.models import MLP
from torchrl.modules.models.multiagent import (
MultiAgentMLP,
MultiAgentNetBase
)
from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector, RandomPolicy
from torchrl.objectives import SACLoss, SoftUpdate
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential, NormalParamExtractor
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.envs import EnvCreator, TransformedEnv, Compose, Transform, RewardSum, ObservationNorm, StepCounter
from torchrl.record import CSVLogger, VideoRecorder, PixelRenderTransform
import multiprocessing as mp
EPS = 1e-7
class SMACCNet(MultiAgentNetBase):
'''
This is an MLP policy network for MultiAgent SAC.
This is just a more limited version of MultiAgentMLP.
(https://pytorch.org/rl/main/_modules/torchrl/modules/models/multiagent.html)
'''
def __init__(self,
n_agent_inputs: int | None,
n_agent_outputs: int,
n_agents: int,
centralised: bool,
share_params: bool,
device = 'cpu',
activation_class = nn.Tanh,
**kwargs):
self.n_agents = n_agents
self.n_agent_inputs = n_agent_inputs
self.n_agent_outputs = n_agent_outputs
self.share_params = share_params
self.centralised = centralised
self.activation_class = activation_class
self.device = device
super().__init__(
n_agents=n_agents,
centralised=centralised,
share_params=share_params,
agent_dim=-2,
device = device,
**kwargs,
)
# Copied over from MultiAgentMLP.
def _pre_forward_check(self, inputs):
if inputs.shape[-2] != self.n_agents:
raise ValueError(
f"Multi-agent network expected input with shape[-2]={self.n_agents},"
f" but got {inputs.shape}"
)
# If the model is centralized, agents have full observability
if self.centralised:
inputs = inputs.flatten(-2, -1)
return inputs
def init_net_params(self, net):
def init_layer_params(layer):
if isinstance(layer, nn.Linear):
weight_gain = 1. / (100 if layer.out_features == self.n_agent_outputs else 1)
torch.nn.init.xavier_uniform_(layer.weight, gain = weight_gain)
if 'bias' in layer.state_dict():
torch.nn.init.zeros_(layer.bias)
net.apply(init_layer_params)
return net
def _build_single_net(self, *, device, **kwargs):
n_agent_inputs = self.n_agent_inputs
if self.centralised and n_agent_inputs is not None:
n_agent_inputs = self.n_agent_inputs * self.n_agents
model = nn.Sequential(
nn.Linear(n_agent_inputs, 400),
self.activation_class(),
nn.Linear(400, 300),
self.activation_class(),
nn.Linear(300, self.n_agent_outputs)
).to(self.device) # Bandaid fix to use MultiSyncDataCollector
model = self.init_net_params(model)
return model
class TqdmLoggingHandler(logging.Handler):
def __init__(self, level=logging.NOTSET):
super().__init__(level)
def emit(self, record):
try:
msg = self.format(record)
tqdm.tqdm.write(msg)
self.flush()
except Exception:
self.handleError(record)
# Main Function
if __name__ == "__main__":
logging.basicConfig(level = logging.INFO)
logger = logging.getLogger(__name__)
logger.propagate = False
logger.addHandler(TqdmLoggingHandler())
mp.set_start_method("spawn", force = True)
NUM_AGENTS = 3
NUM_CRITICS = 2
NUM_EXPLORE_WORKERS = 8
EXPLORATION_STEPS = 30000
MAX_EPISODE_STEPS = 1000
DEVICE = "cuda"
REPLAY_BUFFER_SIZE = int(1e6)
VALUE_GAMMA = 0.99
MAX_GRAD_NORM = 1.0
BATCH_SIZE = 256
LR = 1e-4
UPDATE_STEPS_PER_EXPLORATION = 1500
WARMUP_STEPS = 0 #int(2e5)
TRAIN_TIMESTEPS = int(1e7)
EVAL_INTERVAL = 1 #int(9e4 // EXPLORATION_STEPS) # Every 500k steps or so, evaluate once.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
# https://pytorch.org/rl/stable/tutorials/multiagent_competitive_ddpg.html
# More tutorials: https://pytorch.org/tutorials/advanced/pendulum.html
# Toy test: https://pettingzoo.farama.org/environments/sisl/multiwalker/
def env_fn(mode, parallel = True):
def base_env_fn():
return PettingZooEnv(task = "multiwalker_v9",
parallel = True,
seed = 42,
n_walkers = NUM_AGENTS,
terminate_reward = -5.0,
forward_reward = 1.0,
fall_reward = -1.0,
shared_reward = False,
max_cycles = MAX_EPISODE_STEPS,
render_mode = mode,
device = "cpu"
)
if parallel:
# Don't use.
# https://discuss.pytorch.org/t/pettingzoo-trouble-running-multiple-marl-environments-in-parallel/203706/
env = lambda: ParallelEnv(num_workers = 4, # noqa: E731
create_env_fn = base_env_fn,
device = "cpu",
mp_start_method = "spawn",
serial_for_single = True
)
else:
env = base_env_fn # noqa: E731
def env_with_transforms():
# dummy_env = base_env_fn()
# dummy_obs_transform = ObservationNorm(in_keys = [("walker", "observation")], standard_normal = True)
# dummy_env = TransformedEnv(dummy_env, dummy_obs_transform)
# dummy_obs_transform.init_stats(10000)
init_env = env()
# obs_transform = ObservationNorm(loc = dummy_obs_transform.loc + EPS,
# scale = dummy_obs_transform.scale + EPS,
# in_keys = [("walker", "observation")],
# standard_normal = True
# )
init_env = TransformedEnv(init_env, Compose(
StepCounter(max_steps = MAX_EPISODE_STEPS),
RewardSum(
in_keys = [init_env.reward_key for _ in range(NUM_AGENTS)],
out_keys = [("walker", "episode_reward")] * NUM_AGENTS,
reset_keys = ["_reset"] * NUM_AGENTS
),
# obs_transform
)
)
# del dummy_env, dummy_obs_transform
return init_env
return env_with_transforms
train_env = env_fn(None, parallel = False)()
if train_env.is_closed:
train_env.start()
eval_env = env_fn("rgb_array", parallel = False)()
video_recorder = VideoRecorder(
CSVLogger("multiwalker-toy-test", video_format = "mp4"),
tag = "rendered",
in_keys = ["pixels_record"]
)
# Call the parent's render function
eval_env.append_transform(PixelRenderTransform(out_keys = ["pixels_record"]))
eval_env.append_transform(video_recorder)
if eval_env.is_closed:
eval_env.start()
check_env_specs(train_env)
check_env_specs(eval_env)
print(f"Action: {train_env.full_action_spec}, Reward: {train_env.full_reward_spec}, Done: {train_env.full_done_spec}, Observation: {train_env.full_observation_spec}")
print(f"group_map: {train_env.group_map}")
print(f"Action: {train_env.action_keys}, Reward: {train_env.reward_keys}, Done: {train_env.done_keys}")
# NOTE: The input and output spaces to be fed in are on a PER-AGENT basis.
# Basically, if you have 16 agents observing 3D velocity and outputting speed (the magnitude),
# n_agent_inputs = 3, n_agent_outputs = 1.
obs_dim = train_env.full_observation_spec["walker", "observation"].shape[-1]
action_dim = train_env.full_action_spec["walker", "action"].shape[-1]
policy_net = nn.Sequential(
SMACCNet(n_agent_inputs = obs_dim,
n_agent_outputs = 2 * action_dim,
n_agents = NUM_AGENTS,
centralised = False,
share_params = True,
device = "cpu",
activation_class = nn.LeakyReLU,
),
NormalParamExtractor(),
)
critic_net = SMACCNet(n_agent_inputs = obs_dim + action_dim,
n_agent_outputs = 1,
n_agents = NUM_AGENTS,
centralised = True,
share_params = True,
device = "cpu",
activation_class = nn.LeakyReLU,
)
# Hook our networks to TensorDictModules so they can be a part of the TensorDict pipeline...
policy_net_td_module = TensorDictModule(module = policy_net,
in_keys = [("walker", "observation")],
# NOTE: These outputs must match with the parameter names of the
# distribution you are using!
out_keys = [("walker", "loc"), ("walker", "scale")]
)
obs_act_module = TensorDictModule(lambda obs, act: torch.cat([obs, act], dim = -1),
in_keys = [("walker", "observation"), ("walker", "action")],
out_keys = [("walker", "obs_act")]
)
critic_net_td_module = TensorDictModule(module = critic_net,
in_keys = [("walker", "obs_act")],
out_keys = [("walker", "state_action_value")]
)
# Attach our raw policy network to a probabilistic actor
policy_actor = ProbabilisticActor(
module = policy_net_td_module,
spec = train_env.full_action_spec["walker", "action"],
in_keys = [("walker", "loc"), ("walker", "scale")],
out_keys = [("walker", "action")],
# TanhNormal is based on PyTorch's TanhTransform, which, as far as we know,
# implements a numerically stable log-det-Jacobian.
distribution_class = TanhNormal,
distribution_kwargs = {
"min": train_env.full_action_spec["walker", "action"].space.low,
"max": train_env.full_action_spec["walker", "action"].space.high,
},
return_log_prob = True,
)
with torch.no_grad():
fake_td = train_env.fake_tensordict()
policy_actor(fake_td)
dora = AdditiveGaussianWrapper(
policy = policy_actor,
action_key = ("walker", "action"),
sigma_init = 0.3,
sigma_end = 0.1,
annealing_num_steps = TRAIN_TIMESTEPS // 2
)
critic_actor = TensorDictSequential(
obs_act_module, critic_net_td_module
)
collector = MultiSyncDataCollector(
[env_fn(None, parallel = False) for _ in range(NUM_EXPLORE_WORKERS)],
policy = dora,
frames_per_batch = BATCH_SIZE,
max_frames_per_traj = 0,
total_frames = TRAIN_TIMESTEPS,
device = "cpu",
reset_at_each_iter = False
)
replay_buffer = TensorDictReplayBuffer(
storage = LazyMemmapStorage(
REPLAY_BUFFER_SIZE, device = "cpu",
), # We will store up to memory_size multi-agent transitions
sampler = RandomSampler(),
batch_size = BATCH_SIZE, # We will sample batches of this size
)
sac_loss = SACLoss(policy_actor.to(DEVICE),
qvalue_network = critic_actor.to(DEVICE),
num_qvalue_nets = 2,
loss_function = "l2",
delay_qvalue = True,
alpha_init = 0.1
)
sac_loss.set_keys(
action = ("walker", "action"),
state_action_value = ("walker", "state_action_value"),
reward = ("walker", "reward"),
done = ("walker", "done"),
terminated = ("walker", "terminated"),
)
sac_loss.make_value_estimator(gamma = VALUE_GAMMA)
polyak_updater = SoftUpdate(sac_loss, tau = 0.005)
critic_params = list(sac_loss.qvalue_network_params.flatten_keys().values())
actor_params = list(sac_loss.actor_network_params.flatten_keys().values())
optimizer_actor = torch.optim.Adam(
actor_params,
lr = LR,
weight_decay = 5e-4,
eps = EPS,
betas = (0.9, 0.98)
)
optimizer_critic = torch.optim.Adam(
critic_params,
lr = LR,
weight_decay = 5e-4,
eps = EPS,
betas = (0.9, 0.98)
)
optimizer_alpha = torch.optim.Adam(
[sac_loss.log_alpha],
lr = LR,
eps = EPS,
betas = (0.9, 0.98)
)
# breakpoint()
num_frames = 0
pbar = tqdm.tqdm(total = TRAIN_TIMESTEPS)
total_frames = 0
backprop_ctr = 0
train_rews, ep_lengths = [], []
EXPLORATION_BATCHES = EXPLORATION_STEPS // BATCH_SIZE
for i, tensordict in enumerate(collector):
collector.update_policy_weights_()
pbar.update(tensordict.numel())
tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
# Add to replay buffer
replay_buffer.extend(tensordict.cpu())
total_frames += current_frames
backprop_ctr += 1
# Optimization steps
if total_frames >= WARMUP_STEPS and backprop_ctr > EXPLORATION_BATCHES:
backprop_ctr = 0
losses = TensorDict({}, batch_size = [UPDATE_STEPS_PER_EXPLORATION])
alphas = TensorDict({}, batch_size = [UPDATE_STEPS_PER_EXPLORATION])
for j in range(UPDATE_STEPS_PER_EXPLORATION):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
if str(sampled_tensordict.device) != DEVICE:
sampled_tensordict = sampled_tensordict.to(DEVICE, non_blocking = False)
else:
sampled_tensordict = sampled_tensordict.clone()
try:
# Compute loss
loss_td = sac_loss(sampled_tensordict)
except KeyError:
raise Exception(f"wtf {sampled_tensordict}\n{obs_act_module(sampled_tensordict)['walker', 'obs_act']}")
actor_loss = loss_td["loss_actor"]
q_loss = loss_td["loss_qvalue"]
alpha_loss = loss_td["loss_alpha"]
# Update actor
optimizer_actor.zero_grad()
actor_loss.backward()
actor_grad_norm = torch.nn.utils.clip_grad_norm_(actor_params, max_norm = MAX_GRAD_NORM)
optimizer_actor.step()
# Update critic
optimizer_critic.zero_grad()
q_loss.backward()
q_grad_norm = torch.nn.utils.clip_grad_norm_(critic_params, max_norm = MAX_GRAD_NORM)
optimizer_critic.step()
# Update alpha
optimizer_alpha.zero_grad()
alpha_loss.backward()
alpha_grad_norm = torch.nn.utils.clip_grad_norm_([sac_loss.log_alpha], max_norm = MAX_GRAD_NORM)
optimizer_alpha.step()
losses[j] = loss_td.select(
"loss_actor", "loss_qvalue", "loss_alpha"
).detach()
alphas[j] = loss_td.select("alpha")
# Update qnet_target params
polyak_updater.step()
# Some other stuff I ripped out from https://github.com/pytorch/rl/blob/main/sota-implementations/sac/sac.py
episode_end = (
tensordict["next", "done"]
if tensordict["next", "done"].any()
else tensordict["next", "truncated"]
)
opening_banner = "-" * 10 + f" Batch {i + 1} " + "-" * 10
def get_mean(src, key):
return src.get(key).mean().item()
logger.info(opening_banner)
logger.info(f"Average Actor Loss: {get_mean(losses, 'loss_actor')}")
logger.info(f"Average Q Loss: {get_mean(losses, 'loss_qvalue')}")
logger.info(f"Average Alpha: {get_mean(alphas, 'alpha')} (Loss: {get_mean(losses, 'loss_alpha')})")
logger.info("-" * len(opening_banner))
ep_length = tensordict['next', 'step_count'][episode_end].to(dtype = torch.float64)
if ep_length.numel():
ep_lengths.append(ep_length.mean().item())
agent_terminated = torch.stack(
[
tensordict["next", "walker", "done"][:, agent_id, 0]
if tensordict["next", "walker", "done"][:, agent_id, 0].any()
else tensordict["next", "walker", "truncated"][:, agent_id, 0]
for agent_id in range(NUM_AGENTS)
],
dim = 1
)
train_reward = tensordict['next', 'walker', 'episode_reward'][agent_terminated]
if train_reward.numel():
train_rews.append(train_reward.mean().item())
if not ((i + 1) % (EVAL_INTERVAL * EXPLORATION_BATCHES)):
logger.info(
f"Mean Train Reward Across Past {EVAL_INTERVAL} Collections: " +
(
f"{sum(train_rews) / len(train_rews)}"
if len(train_rews)
else f"NA (Training starts @ {WARMUP_STEPS} steps)"
)
)
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
eval_rollout = eval_env.rollout(
MAX_EPISODE_STEPS,
policy_actor,
auto_cast_to_device=True,
break_when_any_done=True,
)
mean_eval_length = eval_rollout["next", "step_count"][-1].to(dtype = torch.float64).mean().item()
logger.info(f"Mean Eval Reward: {eval_rollout['next', 'walker', 'episode_reward'][-1].mean().item()}")
logger.info(f"Eval Length: {mean_eval_length}")
ep_reward_list = []
train_rews = []
eval_env.transform.dump()
collector.shutdown()
train_env.close()
Expected behavior
Loss values stay within ~ +/- 10^2 throughout training and do not increase to ~ +/- 10^x where x >> 1.
System info
>>> import torchrl, numpy, sys
>>> print(f"TorchRL: {torchrl.__version__}\nNumPy: {numpy.__version__}\nPython3 Ver: {sys.version}\nPlatform: {sys.platform}")
TorchRL: 0.4.0
NumPy: 1.25.0
Python3 Ver: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
Platform: linux
> lsb_release -a
Distributor ID: Ubuntu
Description: Ubuntu 22.04.3 LTS
Release: 22.04
Codename: jammy
Reason and Possible fixes
Although the environment's observation space is not normalized and contains unbounded entries, the issue does not appear to stem entirely from poor observation scaling, since adding a torchrl.envs.ObservationNorm transform does not mitigate it.
Debugging reveals that unusually large negative log_prob values are being fed into the SACLoss computation by TorchRL's reimplementation of torch.distributions.transforms.TanhTransform:
https://github.com/pytorch/rl/blob/3e6cb8419df56d9263d1daa48f9c3be5f01eaea6/torchrl/modules/distributions/continuous.py#L289-L382
Since this reimplementation differs little from the original TanhTransform, it is plausible that the reimplementation itself is NOT the root cause. Nevertheless, replacing it with the alternative variant below gets rid of the issue altogether:
class CustomTanhTransform(D.transforms.TanhTransform):
def _inverse(self, y):
# from stable_baselines3's `common.distributions.TanhBijector`
"""
Inverse of Tanh
Taken from Pyro: https://github.com/pyro-ppl/pyro
0.5 * torch.log((1 + x ) / (1 - x))
"""
y = y.clamp(-1. + EPS, 1. - EPS)
return 0.5 * (y.log1p() - (-y).log1p())
def log_abs_det_jacobian(self, x, y):
# From PyTorch `TanhTransform`
'''
tl;dr log(1-tanh^2(x)) = log(sech^2(x))
= 2log(2/(e^x + e^(-x)))
= 2(log2 - log(e^x/(1 + e^(-2x)))
= 2(log2 - x - log(1 + e^(-2x)))
= 2(log2 - x - softplus(-2x))
'''
return 2.0 * (math.log(2.0) - x - nn.functional.softplus(-2.0 * x))
class TanhNormalStable(D.TransformedDistribution):
def __init__(self, loc, scale, event_dims = 1):
self._event_dims = event_dims
self._t = [
CustomTanhTransform()
]
self.update(loc, scale)
def log_prob(self, value):
"""
Scores the sample by inverting the transform(s) and computing the score
using the score of the base distribution and the log abs det jacobian.
"""
if self._validate_args:
self._validate_sample(value)
event_dim = len(self.event_shape)
log_prob = 0.0
y = value
for transform in reversed(self.transforms):
x = transform.inv(y)
event_dim += transform.domain.event_dim - transform.codomain.event_dim
log_prob = log_prob - D.utils._sum_rightmost(
transform.log_abs_det_jacobian(x, y),
event_dim - transform.domain.event_dim,
)
y = x
log_prob = log_prob + D.utils._sum_rightmost(
self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
)
log_prob = torch.clamp(log_prob, min = math.log10(EPS)) # <- **CLAMPING THIS SEEMS TO RESOLVE THE ISSUE**
return log_prob
def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
self.loc = loc
self.scale = scale
if (
hasattr(self, "base_dist")
and (self.base_dist.base_dist.loc.shape == self.loc.shape)
and (self.base_dist.base_dist.scale.shape == self.scale.shape)
):
self.base_dist.base_dist.loc = self.loc
self.base_dist.base_dist.scale = self.scale
else:
base = D.Independent(D.Normal(self.loc, self.scale), self._event_dims)
super().__init__(base, self._t)
@property
def mode(self):
m = self.base_dist.base_dist.mean
for t in self.transforms:
m = t(m)
return m
Note, however, that such a fix flies in the face of this comment from the PyTorch devs.
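For reference, the stabilized distribution can be dropped into the actor from the reproduction script as sketched below. This is a sketch only: TanhNormalStable as written does not accept the min/max bounds passed to TanhNormal above, so those kwargs are omitted (MultiWalker's actions already live in [-1, 1], so no rescaling is lost).
# Sketch: swap the stock TanhNormal for the stabilized variant defined above.
# The min/max distribution_kwargs are dropped because TanhNormalStable does not rescale actions.
policy_actor = ProbabilisticActor(
    module = policy_net_td_module,
    spec = train_env.full_action_spec["walker", "action"],
    in_keys = [("walker", "loc"), ("walker", "scale")],
    out_keys = [("walker", "action")],
    distribution_class = TanhNormalStable,
    return_log_prob = True,
)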
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)
Any chance this is solved by #2198? If so, let's redirect the discussion to #2186.
I don't think this issue relates to the mode or the mean of the distribution (as I think those are not used in SAC, but I could be wrong).
The logp seems to be the core of these instabilities. I have also experienced that in the past. Clamping tricks are helpful, but we have to be careful about how we do them. I would suggest looking at how others implement this and seeing what works best while still being somewhat mathematically grounded.
For example, this is RLlib's implementation, with some arbitrary constants in the code: https://github.com/ray-project/ray/blob/e6e21ac2bba8b88c66c88b553a40b21a1c78f0a4/rllib/models/torch/torch_distributions.py#L275-L284
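Roughly, that pattern clamps the intermediate Gaussian log-prob before subtracting the tanh correction. A hedged paraphrase is sketched below; it is not the linked code verbatim, and LOGP_CLAMP / SQUASH_EPS are illustrative stand-ins for RLlib's constants.
import torch
import torch.distributions as D

LOGP_CLAMP = 100.0   # illustrative bound on the per-dimension Gaussian log-prob
SQUASH_EPS = 1e-6    # illustrative epsilon inside the tanh correction

def squashed_gaussian_log_prob(base_dist: D.Normal, pre_tanh: torch.Tensor) -> torch.Tensor:
    # Log-prob of action = tanh(pre_tanh) under a diagonal Gaussian base distribution.
    logp = base_dist.log_prob(pre_tanh)
    logp = torch.clamp(logp, -LOGP_CLAMP, LOGP_CLAMP)  # the intermediate clamping trick
    logp = logp.sum(dim = -1)
    squashed = torch.tanh(pre_tanh)
    # Change-of-variables correction for the squashing, padded to avoid log(0).
    return logp - torch.log(1.0 - squashed.pow(2) + SQUASH_EPS).sum(dim = -1)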
This is Stable Baselines3's:
def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
# Inverse tanh
# Naive implementation (not stable): 0.5 * torch.log((1 + x) / (1 - x))
# We use numpy to avoid numerical instability
if gaussian_actions is None:
# It will be clipped to avoid NaN when inversing tanh
gaussian_actions = TanhBijector.inverse(actions)
# Log likelihood for a Gaussian distribution
log_prob = super(SquashedDiagGaussianDistribution, self).log_prob(gaussian_actions)
# Squash correction (from original SAC implementation)
# this comes from the fact that tanh is bijective and differentiable
log_prob -= th.sum(th.log(1 - actions ** 2 + self.epsilon), dim=1)
return log_prob
Very similar to RLlib's, but without the intermediate clamping trick.
OK, got it. I played a lot with the Tanh transform back in the day, and the TL;DR is that anything you do (clamp or no clamp) will degrade performance for someone. What about giving the option to use the "safe" tanh (with clamping) or not? Another option is: cast values from float32 to float64, do the tanh, then cast back to float32. This could also be controlled via a flag in the TanhNormal constructor.
>>> x = torch.full((1,), 10.0)
>>> x.tanh().atanh()
tensor([inf])
>>> x.double().tanh().atanh().float()
tensor([10.])
Note that in practice this is unlikely to help in many cases, since casting back to float32 after tanh() still breaks everything:
>>> x.double().tanh().float().double().atanh().float()
tensor([inf])
I like the idea of letting the user choose between the mathematically pure version and the empirically more stable one via a flag. I wouldn't call it `safe`, though, as that name is already used in other contexts; what about `clamp_logp`?
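For the sake of discussion, one way such a flag could look on the distribution side is sketched below, reusing the TanhNormalStable class and EPS constant from the issue above. The clamp_logp and logp_min names are hypothetical, not an existing TorchRL API.
# Hypothetical sketch only: gate the clamp behind an opt-in constructor flag.
class TanhNormalOptionalClamp(TanhNormalStable):
    def __init__(self, loc, scale, event_dims = 1, clamp_logp = True, logp_min = math.log10(EPS)):
        self.clamp_logp = clamp_logp
        self.logp_min = logp_min
        super().__init__(loc, scale, event_dims = event_dims)

    def log_prob(self, value):
        # Stock TransformedDistribution log-prob (mathematically pure, no clamp)...
        logp = D.TransformedDistribution.log_prob(self, value)
        # ...optionally clamped for numerical stability.
        if self.clamp_logp:
            logp = torch.clamp(logp, min = self.logp_min)
        return logp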