
Loss using SwaV getting stuck at 6.24

Open sarmientoj24 opened this issue 4 years ago • 3 comments

I've used the given example for SwaV with a ResNet-50:

import torch
import torchvision
from torch import nn
import pytorch_lightning as pl

from lightly.data import LightlyDataset, SwaVCollateFunction
from lightly.loss import SwaVLoss
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes

class SwaV(pl.LightningModule):
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet50()
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.projection_head = SwaVProjectionHead(2048, 512, 128)
        self.prototypes = SwaVPrototypes(128, n_prototypes=512)
        self.criterion = SwaVLoss()

    def forward(self, x):
        x = self.backbone(x).flatten(start_dim=1)
        x = self.projection_head(x)
        x = nn.functional.normalize(x, dim=1, p=2)
        p = self.prototypes(x)
        return p

    def training_step(self, batch, batch_idx):
        self.prototypes.normalize()
        crops, _, _ = batch
        multi_crop_features = [self.forward(x.to(self.device)) for x in crops]
        high_resolution = multi_crop_features[:2]
        low_resolution = multi_crop_features[2:]
        loss = self.criterion(high_resolution, low_resolution)
        self.log("loss", loss, on_epoch=True)
        
        log_dict = {"train_loss": loss}

        return {"loss": loss, "log": log_dict}

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=0.001)
        return optim

model = SwaV()
dataset = LightlyDataset("data/")
collate_fn = SwaVCollateFunction()

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

trainer = pl.Trainer(
    max_epochs=400, gpus=[0],
    logger=wandb_logger,
    log_every_n_steps=2,
    callbacks=[checkpoint_callback],
)
...

But the loss has been stuck at 6.24 and never moves.

I tried the following edits as well:

  • changed the learning rate from 0.01 to 0.6
  • added CosineAnnealing
  • set resnet50 to pretrained=True
  • modified the SwaVCollateFunction()

But it still gets stuck at 6.24.

[image: training loss curve]

sarmientoj24 avatar Mar 11 '22 16:03 sarmientoj24

One thing to keep in mind is that these techniques are very batch-size sensitive. I would recommend increasing the batch size to at least 256 or 512 to start seeing results. I suspect your model has collapsed due to the small batch size.
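Incidentally, the exact value the loss plateaus at supports the collapse diagnosis: with `n_prototypes=512`, a model that assigns uniform probability to every prototype yields a cross-entropy of ln(512) ≈ 6.24, which is precisely the reported plateau. A quick check (plain Python, nothing thread-specific assumed):

```python
import math

# Cross-entropy of a uniform distribution over K classes is ln(K).
# With the 512 prototypes configured in the snippet above, that gives:
n_prototypes = 512
collapsed_loss = math.log(n_prototypes)
print(round(collapsed_loss, 2))  # 6.24
```

So a loss frozen at ~6.24 means the prototype predictions are uniform, i.e. the model has collapsed rather than merely learning slowly.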

Please let me know if you have other questions.

Atharva-Phatak avatar Mar 13 '22 01:03 Atharva-Phatak

Okay, I tried changing the batch size @Atharva-Phatak, but I am getting this error:

RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/james_sarmiento/anaconda3/envs/swav/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/james_sarmiento/anaconda3/envs/swav/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/james_sarmiento/anaconda3/envs/swav/lib/python3.6/site-packages/pytorch_lightning/overrides/data_parallel.py", line 63, in forward
    output = super().forward(*inputs, **kwargs)
  File "/home/james_sarmiento/anaconda3/envs/swav/lib/python3.6/site-packages/pytorch_lightning/overrides/base.py", line 81, in forward
    output = self.module.training_step(*inputs, **kwargs)
  File "test_swav.py", line 37, in training_step
    multi_crop_features = [self.forward(x) for x in crops]
  File "test_swav.py", line 37, in <listcomp>
    multi_crop_features = [self.forward(x) for x in crops]
  File "test_swav.py", line 31, in forward
    p = self.prototypes(x)
  File "/home/james_sarmiento/anaconda3/envs/swav/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/james_sarmiento/anaconda3/envs/swav/lib/python3.6/site-packages/lightly/models/modules/heads.py", line 55, in forward
    return self.layers(x)
  File "/home/james_sarmiento/anaconda3/envs/swav/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/james_sarmiento/anaconda3/envs/swav/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/james_sarmiento/anaconda3/envs/swav/lib/python3.6/site-packages/torch/nn/functional.py", line 1692, in linear
    output = input.matmul(weight.t())
RuntimeError: Output 165 of BroadcastBackward is a view and its base or another view of its base has been modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.
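The `BroadcastBackward` in this trace suggests the run is using DataParallel, which broadcasts parameters to each replica as views; the in-place `self.prototypes.normalize()` in `training_step` then modifies those views and triggers this error. A hedged workaround (flag names assume PyTorch Lightning ≥ 1.5; older releases spelled this `accelerator="ddp"`) is to switch the Trainer to DDP, which gives each process its own parameter copy instead of broadcast views:

```python
# Sketch only: swaps DataParallel for DistributedDataParallel.
# Assumes pytorch_lightning >= 1.5 flag names.
trainer = pl.Trainer(
    max_epochs=400,
    accelerator="gpu",
    devices=2,
    strategy="ddp",
)
```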

sarmientoj24 avatar Mar 13 '22 08:03 sarmientoj24

Are you running a DDP setup with Lightning? Can you paste the code you wrote for Lightning?

Atharva-Phatak avatar Mar 13 '22 23:03 Atharva-Phatak

I have the same problem. SwaV should work well with small batch sizes, unlike models that rely on batch-wise negative sampling such as SimCLR. However, the SwaV implementation is missing a feature that's crucial for training with small batch sizes (https://github.com/lightly-ai/lightly/issues/1006). Hopefully once it's resolved we'll see it perform better.
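For context, the missing feature is SwaV's feature queue: features from past batches are kept in a FIFO buffer and concatenated with the current batch before computing prototype assignments, so the Sinkhorn step effectively sees far more samples than the batch size alone provides. A minimal, library-free sketch of the idea (the `FeatureQueue` class is a hypothetical illustration, not lightly's API):

```python
from collections import deque

class FeatureQueue:
    """FIFO buffer of past-batch features; the oldest entries are
    evicted once `size` is reached. SwaV concatenates this buffer with
    the current batch before computing prototype assignments."""
    def __init__(self, size):
        self.size = size
        self.items = deque()

    def push(self, batch_features):
        for f in batch_features:
            self.items.append(f)
            if len(self.items) > self.size:
                self.items.popleft()  # evict the oldest feature

    def all(self):
        return list(self.items)

q = FeatureQueue(size=4)
q.push([1, 2, 3])
q.push([4, 5, 6])
print(q.all())  # [3, 4, 5, 6]
```

With a queue of a few thousand features, the assignment problem is well-conditioned even when each batch contributes only 32 samples.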

ibro45 avatar Dec 16 '22 13:12 ibro45

Support for SwaV with small batch sizes was added in #1010

guarin avatar Feb 07 '23 14:02 guarin