
The Q networks' NormedLinear layers are not initialized.


Since the merge enabling torch.compile, the line self.apply(init.weight_init) in WorldModel.__init__() no longer initializes the Q networks' layers: self._Qs.params is no longer an nn.ParameterList but a TensorDictParams (see init.weight_init), so the isinstance check never matches and those layers keep only the default nn.Linear initialization.

The last layer of each Q network is still zeroed by the next line, init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]]), so that one is fine.
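The failure mode can be reproduced in miniature without tensordict (a self-contained sketch, not TD-MPC2 code: StackedEnsemble and the sentinel value are hypothetical stand-ins). nn.Module.apply only passes modules to its callback, so parameters stacked into a single 3-D tensor, the way TensorDictParams effectively holds them, match neither isinstance branch of weight_init:

```python
import torch
import torch.nn as nn

def weight_init(m):
	"""Simplified mirror of TD-MPC2's init.weight_init."""
	if isinstance(m, nn.Linear):
		nn.init.trunc_normal_(m.weight, std=0.02)
		if m.bias is not None:
			nn.init.constant_(m.bias, 0)
	elif isinstance(m, nn.ParameterList):
		for i, p in enumerate(m):
			if p.dim() == 3:  # stacked Linear weights
				nn.init.trunc_normal_(p, std=0.02)
				nn.init.constant_(m[i + 1], 0)

class StackedEnsemble(nn.Module):
	"""Hypothetical stand-in: ensemble weights stacked in one 3-D tensor,
	not held in an nn.ParameterList and not an nn.Linear submodule."""
	def __init__(self, n, d):
		super().__init__()
		self.weight = nn.Parameter(torch.empty(n, d, d))
		nn.init.constant_(self.weight, 1.0)  # sentinel value

ens = StackedEnsemble(5, 4)
ens.apply(weight_init)
# Neither branch fires: ens is not nn.Linear and has no nn.ParameterList,
# so the sentinel value survives untouched.
print(torch.all(ens.weight == 1.0).item())
```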

WorldModel.py

	def __init__(self, cfg):
		super().__init__()
		[...]
		self._Qs = layers.Ensemble([
			layers.mlp(
				cfg.latent_dim + cfg.action_dim + cfg.task_dim,
				2 * [cfg.mlp_dim],
				max(cfg.num_bins, 1),
				dropout=cfg.dropout,
			)
			for _ in range(cfg.num_q)
		])
		self.apply(init.weight_init)
		init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]])

layers.py

class Ensemble(nn.Module):
	"""
	Vectorized ensemble of modules.
	"""

	def __init__(self, modules, **kwargs):
		super().__init__()
		# combine_state_for_ensemble causes graph breaks
		self.params = from_modules(*modules, as_module=True)
		with self.params[0].data.to("meta").to_module(modules[0]):
			self.module = deepcopy(modules[0])
		self._repr = str(modules[0])
		self._n = len(modules)

init.py

def weight_init(m):
	"""Custom weight initialization for TD-MPC2."""
	if isinstance(m, nn.Linear):
		nn.init.trunc_normal_(m.weight, std=0.02)
		if m.bias is not None:
			nn.init.constant_(m.bias, 0)
	elif isinstance(m, nn.Embedding):
		nn.init.uniform_(m.weight, -0.02, 0.02)
	elif isinstance(m, nn.ParameterList):
		for i,p in enumerate(m):
			if p.dim() == 3: # Linear
				nn.init.trunc_normal_(p, std=0.02) # Weight
				nn.init.constant_(m[i+1], 0) # Bias
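One possible workaround (a hypothetical sketch, not code from the repository; the helper name and the assumption that stacked Linear weights appear as 3-D tensors named "...weight" under named_parameters are mine): initialize the stacked parameters directly instead of relying on apply's isinstance checks:

```python
import torch
import torch.nn as nn

def init_stacked_linear_(module: nn.Module) -> None:
	"""Hypothetical fix: walk all parameters and initialize any 3-D
	'weight' as a stack of Linear weight matrices (std=0.02 truncated
	normal, matching weight_init) and any 2-D 'bias' as stacked biases.
	Assumes vmap-style ensembles expose their parameters this way."""
	for name, p in module.named_parameters():
		if p.dim() == 3 and name.endswith("weight"):
			nn.init.trunc_normal_(p, std=0.02)
		elif p.dim() == 2 and name.endswith("bias"):
			nn.init.constant_(p, 0)

# Usage on a toy module holding stacked ensemble parameters:
class Stacked(nn.Module):
	def __init__(self):
		super().__init__()
		self.weight = nn.Parameter(torch.ones(3, 4, 4))
		self.bias = nn.Parameter(torch.ones(3, 4))

m = Stacked()
init_stacked_linear_(m)
print(torch.all(m.bias == 0).item())
```

Whether this matches the exact parameter-key layout of TensorDictParams would need to be checked against the actual Ensemble module.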

niamorg avatar Jun 26 '25 00:06 niamorg

Great catch! Let me look into this and get back to you soon. I suspect that it doesn't really matter in practice but would be good to check.

nicklashansen avatar Jun 27 '25 18:06 nicklashansen