Loss using SwaV getting stuck at 6.24
I've used the given SwaV example with a ResNet50 backbone:
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn

from lightly.data import LightlyDataset, SwaVCollateFunction
from lightly.loss import SwaVLoss
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes

class SwaV(pl.LightningModule):
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet50()
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.projection_head = SwaVProjectionHead(2048, 512, 128)
        self.prototypes = SwaVPrototypes(128, n_prototypes=512)
        self.criterion = SwaVLoss()

    def forward(self, x):
        x = self.backbone(x).flatten(start_dim=1)
        x = self.projection_head(x)
        x = nn.functional.normalize(x, dim=1, p=2)
        p = self.prototypes(x)
        return p

    def training_step(self, batch, batch_idx):
        self.prototypes.normalize()
        crops, _, _ = batch
        multi_crop_features = [self.forward(x.to(self.device)) for x in crops]
        high_resolution = multi_crop_features[:2]
        low_resolution = multi_crop_features[2:]
        loss = self.criterion(high_resolution, low_resolution)
        self.log("loss", loss, on_epoch=True)
        log_dict = {"train_loss": loss}
        return {"loss": loss, "log": log_dict}

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=0.001)
        return optim

model = SwaV()
dataset = LightlyDataset("data/")
collate_fn = SwaVCollateFunction()
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

# wandb_logger and checkpoint_callback are defined elsewhere
trainer = pl.Trainer(
    max_epochs=400,
    gpus=[0],
    logger=wandb_logger,
    log_every_n_steps=2,
    callbacks=[checkpoint_callback],
)
...
But the loss is stuck at 6.24 and never moves.
I tried the following edits as well:
- changed the learning rate from 0.01 to 0.6
- added CosineAnnealing
- made resnet50 pretrained=True
- modified the SwaVCollateFunction()

But it still gets stuck at 6.24.

One thing to keep in mind: these techniques are very batch-size sensitive. I would recommend increasing the batch size to at least 256 or 512 before expecting to see results. I suspect your model has collapsed due to the small batch size.
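When raising the batch size, a common companion change in self-supervised setups is the linear learning-rate scaling rule. A minimal, library-agnostic sketch of that arithmetic (the base LR and reference batch size here are illustrative assumptions, not lightly defaults):

```python
# Hedged sketch: linear LR scaling when increasing the batch size.
# Constants are illustrative; tune them for your own setup.
def scaled_lr(base_lr: float, batch_size: int, base_batch_size: int = 256) -> float:
    """Linear scaling rule: the learning rate grows proportionally
    with the batch size relative to a reference batch size."""
    return base_lr * batch_size / base_batch_size

# Moving from batch size 256 to 512 doubles the effective LR:
print(scaled_lr(0.075, 512))  # 0.15
```

This keeps the per-sample update magnitude roughly constant as the batch grows, which matters when you jump from 32 to 256+ samples per step.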
Please let me know if you have other questions.
Okay, I tried changing the batch size, @Atharva-Phatak, but I am getting this error:
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/home/james_sarmiento/anaconda3/envs/swav/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
output = module(*input, **kwargs)
File "/home/james_sarmiento/anaconda3/envs/swav/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/james_sarmiento/anaconda3/envs/swav/lib/python3.6/site-packages/pytorch_lightning/overrides/data_parallel.py", line 63, in forward
output = super().forward(*inputs, **kwargs)
File "/home/james_sarmiento/anaconda3/envs/swav/lib/python3.6/site-packages/pytorch_lightning/overrides/base.py", line 81, in forward
output = self.module.training_step(*inputs, **kwargs)
File "test_swav.py", line 37, in training_step
multi_crop_features = [self.forward(x) for x in crops]
File "test_swav.py", line 37, in <listcomp>
multi_crop_features = [self.forward(x) for x in crops]
File "test_swav.py", line 31, in forward
p = self.prototypes(x)
File "/home/james_sarmiento/anaconda3/envs/swav/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/james_sarmiento/anaconda3/envs/swav/lib/python3.6/site-packages/lightly/models/modules/heads.py", line 55, in forward
return self.layers(x)
File "/home/james_sarmiento/anaconda3/envs/swav/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/james_sarmiento/anaconda3/envs/swav/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 93, in forward
return F.linear(input, self.weight, self.bias)
File "/home/james_sarmiento/anaconda3/envs/swav/lib/python3.6/site-packages/torch/nn/functional.py", line 1692, in linear
output = input.matmul(weight.t())
RuntimeError: Output 165 of BroadcastBackward is a view and its base or another view of its base has been modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.
Are you running a DDP setup with Lightning? Can you paste the code you wrote for Lightning?
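For context on why the backend matters here: the `BroadcastBackward` error in the traceback is typical of `nn.DataParallel`, whose replicas share broadcast views that in-place operations like `prototypes.normalize()` then modify. Switching to DistributedDataParallel usually avoids it. A sketch using the PL 1.x flag style implied by the traceback paths (assumption: newer versions use `strategy="ddp"` instead):

```python
import pytorch_lightning as pl

# Hedged sketch: prefer DDP over DataParallel to avoid in-place
# modification of broadcast views (PL 1.x style flags; adjust to
# your installed version).
trainer = pl.Trainer(
    max_epochs=400,
    gpus=[0],
    accelerator="ddp",  # assumption: PL 1.x flag name
)
```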
I have the same problem. SwaV should work well with small batch sizes, unlike models that do batch-wise negative sampling such as SimCLR. However, there's a missing feature in the SwaV implementation that's crucial for training with small batch sizes (https://github.com/lightly-ai/lightly/issues/1006). Hopefully once it's resolved we should see it performing better.
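The missing feature referenced in that issue is a queue of past features: with small batches, the cluster assignment step sees too few samples, so SwaV keeps a FIFO buffer of features from previous batches to augment the current one. A library-agnostic sketch of the idea (class and method names are illustrative, not lightly's API):

```python
from collections import deque

class FeatureQueue:
    """Hypothetical FIFO queue of past feature vectors. Old batches
    augment the current one so the assignment step sees more samples
    than the (small) batch alone provides."""
    def __init__(self, size: int):
        # deque with maxlen evicts the oldest entries automatically
        self.features = deque(maxlen=size)

    def update(self, batch_features):
        """Append the current batch's features, evicting the oldest."""
        self.features.extend(batch_features)

    def get(self):
        """Return all queued features for use in the assignment step."""
        return list(self.features)

queue = FeatureQueue(size=4)
queue.update([[0.1], [0.2]])
queue.update([[0.3], [0.4], [0.5]])
print(len(queue.get()))  # 4 (the oldest entry was evicted)
```

This is only the storage mechanism; in SwaV the queued features additionally feed into the Sinkhorn assignment alongside the current batch.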
Support for SwaV with small batch sizes was added in #1010.