Why are mask tokens zero‑initialized in V‑JEPA while they are randomly initialized (trunc_normal) in I‑JEPA?
While digging through the predictors I noticed a small but interesting difference in the way the mask tokens are initialized:
```python
# I-JEPA (vision_transformer_predictor.py)
self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
trunc_normal_(self.mask_token, std=init_std)  # ≈ N(0, 0.02²)
```

```python
# V-JEPA (vision_transformer_predictor.py)
self.mask_tokens = nn.ParameterList([
    nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
    for i in range(num_mask_tokens)
])
# ...
if self.predictor_pos_embed is not None:
    self._init_pos_embed(self.predictor_pos_embed.data)
self.init_std = init_std
# with zero_init_mask_tokens=True (the default), this branch never runs,
# so the mask tokens remain exactly zero
if not zero_init_mask_tokens:
    for mt in self.mask_tokens:
        trunc_normal_(mt, std=init_std)
```
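For concreteness, the behavioral difference can be reproduced in a standalone sketch (`embed_dim`, `num_mask_tokens`, and the `std` value here are illustrative placeholders, not the repos' actual configs):

```python
import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_

embed_dim, num_mask_tokens = 384, 2

# I-JEPA style: a single mask token, overwritten with truncated-normal noise
mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
trunc_normal_(mask_token, std=0.02)

# V-JEPA style (zero_init_mask_tokens=True): tokens are left exactly at zero
mask_tokens = nn.ParameterList([
    nn.Parameter(torch.zeros(1, 1, embed_dim))
    for _ in range(num_mask_tokens)
])

print(mask_token.abs().sum().item() > 0)                       # True: non-zero after init
print(all(mt.abs().sum().item() == 0 for mt in mask_tokens))   # True: still all zeros
```

So at the first forward pass, the V-JEPA predictor's mask queries carry no content of their own and are distinguished only by the positional embeddings added to them, whereas the I-JEPA token already injects a small learned-noise vector.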
Could you share the motivation behind switching to zero initialization for the video version?
- Did zero‑init improve training stability for long spatio‑temporal sequences or multi‑mask‑token setups?
- Have you compared convergence speed or final performance between zero‑init and trunc‑normal initialization on the same video benchmarks?
Thanks again for your time and for the great work!