
Why are mask tokens zero‑initialized in V‑JEPA while they are randomly initialized (trunc_normal) in I‑JEPA?

Open · k007ke opened this issue 10 months ago · 0 comments

While digging through the predictors of the two codebases, I noticed a small but interesting difference in how the mask tokens are initialized:

# I‑JEPA (vision_transformer_predictor.py)
self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
trunc_normal_(self.mask_token, std=init_std)   # ≈ N(0, 0.02²)

# V‑JEPA (vision_transformer_predictor.py)
self.mask_tokens = nn.ParameterList([
    nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
    for i in range(num_mask_tokens)
])
if self.predictor_pos_embed is not None:
    self._init_pos_embed(self.predictor_pos_embed.data)
self.init_std = init_std
# zero_init_mask_tokens=True is the default, so this branch is skipped
# and the mask tokens remain exactly zero
if not zero_init_mask_tokens:
    for mt in self.mask_tokens:
        trunc_normal_(mt, std=init_std)
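For illustration, here is a minimal stdlib-only sketch of what the two schemes produce at initialization. It assumes (1) that `trunc_normal_` follows timm's convention, where the bounds `a=-2, b=2` are absolute values rather than multiples of `std`, so with `std=0.02` it behaves essentially like a plain N(0, 0.02²), and (2) that the predictor adds its positional embedding to each mask token before the transformer blocks. The helpers `trunc_normal`, `make_mask_tokens`, `sincos_pos_embed` and `predictor_inputs` are hypothetical and do not come from either repo.

```python
import math
import random


def trunc_normal(n, std=0.02, a=-2.0, b=2.0, rng=random):
    """Draw n samples from N(0, std^2) restricted to [a, b] by rejection.

    Assumption: a and b are absolute bounds (timm convention), so with
    std=0.02 the cut at +/-2 is about 100 std away and almost never rejects.
    """
    out = []
    while len(out) < n:
        x = rng.gauss(0.0, std)
        if a <= x <= b:
            out.append(x)
    return out


def make_mask_tokens(num_mask_tokens, dim, zero_init, std=0.02, seed=0):
    """Hypothetical stand-in for the ParameterList of mask tokens."""
    rng = random.Random(seed)
    if zero_init:
        return [[0.0] * dim for _ in range(num_mask_tokens)]
    return [trunc_normal(dim, std=std, rng=rng) for _ in range(num_mask_tokens)]


def sincos_pos_embed(num_positions, dim):
    """Hypothetical 1-D sin/cos table standing in for predictor_pos_embed."""
    table = []
    for pos in range(num_positions):
        row = []
        for i in range(dim):
            freq = 1.0 / (10000 ** (2 * (i // 2) / dim))
            row.append(math.sin(pos * freq) if i % 2 == 0 else math.cos(pos * freq))
        table.append(row)
    return table


def predictor_inputs(mask_token, pos_embed, masked_positions):
    """Mask token + positional embedding at each masked position (assumed)."""
    return [[t + p for t, p in zip(mask_token, pos_embed[k])] for k in masked_positions]


zero = make_mask_tokens(num_mask_tokens=2, dim=8, zero_init=True)
rand = make_mask_tokens(num_mask_tokens=2, dim=8, zero_init=False)
pos = sincos_pos_embed(num_positions=16, dim=8)
inputs = predictor_inputs(zero[0], pos, masked_positions=[3, 7])

print(zero[0] == zero[1])   # True: zero-init tokens are identical at step 0
print(rand[0] == rand[1])   # False: trunc-normal tokens start distinct
print(inputs[0] == pos[3])  # True: the input is exactly the positional embedding
```

Under these assumptions, zero init means every entry of `mask_tokens` starts out identical, and a masked position's input at step 0 is just its positional embedding; the tokens can only diverge through their own gradients. With trunc-normal init the tokens start distinct but tiny (on the order of 0.02 per coordinate).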

Could you share the motivation behind switching to zero initialization for the video version?

  • Did zero‑init improve training stability for long spatio‑temporal sequences or multi‑mask‑token setups?

  • Have you compared convergence speed or final performance between zero‑init and trunc‑normal on the same video benchmarks?

Thanks again for your time and for the great work!

k007ke, May 07 '25 11:05