PyTorch-Pretrained-ViT
About torch.no_grad()
Hi, thanks for this implementation. I saw that the parameters of nn.Linear() are initialized under @torch.no_grad() in models.py, line 139:
```python
@torch.no_grad()
def init_weights(self):
    def _init(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)  # _trunc_normal(m.weight, std=0.02)  # from .initialization import _trunc_normal
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.normal_(m.bias, std=1e-6)  # nn.init.constant(m.bias, 0)
    self.apply(_init)
    nn.init.constant_(self.fc.weight, 0)
    nn.init.constant_(self.fc.bias, 0)
    nn.init.normal_(self.positional_embedding.pos_embedding, std=0.02)  # _trunc_normal(self.positional_embedding.pos_embedding, std=0.02)
    nn.init.constant_(self.class_token, 0)
```
Does this mean this project only supports eval? Shouldn't these parameters be trainable if I want to train ViT on my own dataset?
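For reference, here is a minimal sketch (with a hypothetical TinyHead module, not taken from this repo) of how one could check whether decorating an init method with @torch.no_grad() affects requires_grad on the parameters. The decorator only disables gradient tracking while the in-place init ops run; it does not freeze the parameters themselves:

```python
# Minimal sketch: verify that parameters initialized inside a
# @torch.no_grad()-decorated method remain trainable afterwards.
import torch
import torch.nn as nn


class TinyHead(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)
        self.init_weights()

    @torch.no_grad()
    def init_weights(self):
        # In-place init ops; no_grad() just keeps them out of autograd's graph.
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.normal_(self.fc.bias, std=1e-6)


model = TinyHead()
print(all(p.requires_grad for p in model.parameters()))  # True: still trainable

# Gradients still flow in a normal training step.
loss = model(torch.randn(4, 8)).sum()
loss.backward()
print(model.fc.weight.grad is not None)  # True: gradient was computed
```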