The Q networks' NormedLinear layers are not initialized.
Since the merge enabling torch.compile, the line self.apply(init.weight_init) in WorldModel.__init__() no longer initializes the Q networks' layers, because self._Qs.params is no longer an nn.ParameterList but a TensorDictParams (see init.weight_init). As a result, those layers only get the default nn.Linear initialization.
The last layer of each Q network is still initialized by the following line, init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]]), so that one is fine.
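To make the type mismatch concrete, here is a minimal sketch (the module sizes are arbitrary and not taken from the repo): the stacked parameters returned by from_modules are not an nn.ParameterList, so neither branch of init.weight_init matches them.

import torch.nn as nn
from tensordict import from_modules

modules = [nn.Linear(4, 4) for _ in range(3)]
params = from_modules(*modules, as_module=True)   # TensorDictParams, not nn.ParameterList
print(type(params).__name__)                      # TensorDictParams
print(isinstance(params, nn.ParameterList))       # False -> weight_init's ParameterList branch is skipped
print(params["weight"].shape)                     # stacked weights, shape (3, 4, 4)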
WorldModel.py
def __init__(self, cfg):
    super().__init__()
    [...]
    self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
    self.apply(init.weight_init)
    init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]])
layers.py
class Ensemble(nn.Module):
    """
    Vectorized ensemble of modules.
    """

    def __init__(self, modules, **kwargs):
        super().__init__()
        # combine_state_for_ensemble causes graph breaks
        self.params = from_modules(*modules, as_module=True)
        with self.params[0].data.to("meta").to_module(modules[0]):
            self.module = deepcopy(modules[0])
        self._repr = str(modules[0])
        self._n = len(modules)
init.py
def weight_init(m):
    """Custom weight initialization for TD-MPC2."""
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Embedding):
        nn.init.uniform_(m.weight, -0.02, 0.02)
    elif isinstance(m, nn.ParameterList):
        for i, p in enumerate(m):
            if p.dim() == 3:  # Linear
                nn.init.trunc_normal_(p, std=0.02)  # Weight
                nn.init.constant_(m[i+1], 0)  # Bias
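A possible workaround is sketched below. It is only a sketch, not the repo's actual fix: init_qs_params is a hypothetical helper, and it assumes that items(include_nested=True, leaves_only=True) on the Ensemble's TensorDictParams yields the underlying parameter tensors. It would be called in WorldModel.__init__() after self.apply(init.weight_init) and before the init.zero_(...) line, mirroring what the old nn.ParameterList branch of weight_init did.

import torch.nn as nn

def init_qs_params(params):
    """Hypothetical helper: apply the TD-MPC2 init to the ensemble's stacked parameters."""
    for key, p in params.items(include_nested=True, leaves_only=True):
        name = key[-1] if isinstance(key, tuple) else key
        if name == "weight" and p.dim() == 3:  # stacked nn.Linear weights
            nn.init.trunc_normal_(p, std=0.02)
        elif name == "bias":                   # stacked biases -> zero
            nn.init.constant_(p, 0)

# e.g. in WorldModel.__init__():
#     self.apply(init.weight_init)
#     init_qs_params(self._Qs.params)
#     init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]])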
Great catch! Let me look into this and get back to you soon. I suspect that it doesn't really matter in practice but would be good to check.