lightly icon indicating copy to clipboard operation
lightly copied to clipboard

LightlyDataset is not working

Open darthvaddor opened this issue 1 year ago • 6 comments

I've been trying to finetune dino on my dataset:

import copy

import torch
import torchvision
from torch import nn

from lightly.loss import DINOLoss
from lightly.data.collate import DINOCollateFunction
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.dino_transform import DINOTransform
from lightly.utils.scheduler import cosine_schedule
from lightly.data import LightlyDataset

import warnings
from PIL import Image

Image.MAX_IMAGE_PIXELS = None

class DINO(torch.nn.Module):
    def __init__(self, backbone, input_dim):
        super().__init__()
        self.student_backbone = backbone
        self.student_head = DINOProjectionHead(
            input_dim, 512, 64, 2048, freeze_last_layer=1
        )
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)
        deactivate_requires_grad(self.teacher_backbone)
        deactivate_requires_grad(self.teacher_head)

    def forward(self, x):
        y = self.student_backbone(x).flatten(start_dim=1)
        z = self.student_head(y)
        return z

    def forward_teacher(self, x):
        y = self.teacher_backbone(x).flatten(start_dim=1)
        z = self.teacher_head(y)
        return z

# resnet = torchvision.models.resnet18()
# backbone = nn.Sequential(*list(resnet.children())[:-1])
# input_dim = 512
# instead of a resnet you can also use a vision transformer backbone as in the
# original paper (you might have to reduce the batch size in this case):
backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=False)
input_dim = backbone.embed_dim

model = DINO(backbone, input_dim)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

class DINO(torch.nn.Module):
    def __init__(self, backbone, input_dim):
        super().__init__()
        self.student_backbone = backbone
        self.student_head = DINOProjectionHead(
            input_dim, 512, 64, 2048, freeze_last_layer=1
        )
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)
        deactivate_requires_grad(self.teacher_backbone)
        deactivate_requires_grad(self.teacher_head)

    def forward(self, x):
        y = self.student_backbone(x).flatten(start_dim=1)
        z = self.student_head(y)
        return z

    def forward_teacher(self, x):
        y = self.teacher_backbone(x).flatten(start_dim=1)
        z = self.teacher_head(y)
        return z

# resnet = torchvision.models.resnet18()
# backbone = nn.Sequential(*list(resnet.children())[:-1])
# input_dim = 512
# instead of a resnet you can also use a vision transformer backbone as in the
# original paper (you might have to reduce the batch size in this case):
backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16', pretrained=False)
input_dim = backbone.embed_dim

model = DINO(backbone, input_dim)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

transform = DINOTransform()

# we ignore object detection annotations by setting target_transform to return 0
def target_transform(t):
    return 0

# dataset = torchvision.datasets.VOCDetection(
#     "/content/drive/MyDrive/Dataset",
#     download=True,
#     transform=transform,
#     target_transform=target_transform,
# )
# or create a dataset from a folder containing images or videos:
dataset = LightlyDataset("/content/drive/MyDrive/Dataset")

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    drop_last=True,
    num_workers=2,
)

criterion = DINOLoss(
    output_dim=2048,
    warmup_teacher_temp_epochs=5,
)
# move loss to correct device because it also contains parameters
criterion = criterion.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

epochs = 10

print("Starting Training")
for epoch in range(epochs):
    total_loss = 0
    momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)
    for batch in dataloader:
        views = batch[0]
        update_momentum(model.student_backbone, model.teacher_backbone, m=momentum_val)
        update_momentum(model.student_head, model.teacher_head, m=momentum_val)
        views = [view.to(device) for view in views]
        global_views = views[:2]
        teacher_out = [model.forward_teacher(view) for view in global_views]
        student_out = [model.forward(view) for view in views]
        loss = criterion(teacher_out, student_out, epoch=epoch)
        total_loss += loss.detach()
        loss.backward()
        # We only cancel gradients of student head.
        model.student_head.cancel_last_layer_gradients(current_epoch=epoch)
        optimizer.step()
        optimizer.zero_grad()

    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

I've used the exact same code given in the documentation. No changes made. My Dataset Folder contains .jpg images The error:

Starting Training
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
[<ipython-input-23-f629995aaa17>](https://localhost:8080/#) in <cell line: 2>()
      3     total_loss = 0
      4     momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)
----> 5     for batch in dataloader:
      6         views = batch[0]
      7         update_momentum(model.student_backbone, model.teacher_backbone, m=momentum_val)

3 frames
[/usr/local/lib/python3.10/dist-packages/torch/_utils.py](https://localhost:8080/#) in reraise(self)
    704             # instantiate since we don't know how to
    705             raise RuntimeError(msg) from None
--> 706         raise exception
    707 
    708 

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 317, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 174, in collate
    return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 174, in <listcomp>
    return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py", line 192, in collate
    raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>

darthvaddor avatar Oct 18 '24 05:10 darthvaddor

@japrescott @IgorSusmelj @guarin @philippmwirth @MalteEbner @michal-lightly @huan-lightly-0 @shaundaley39

darthvaddor avatar Oct 18 '24 05:10 darthvaddor

Hey @JAVARSHA thank you for using our library 🙂 It looks like there's a small error in the docs that you uncovered by using LightlyDataset instead VOCDetection. Can you try adding the transform argument to the LightlyDataset?

dataset = LightlyDataset("/content/drive/MyDrive/Dataset", transform=transform)

philippmwirth avatar Oct 18 '24 06:10 philippmwirth

Hey @JAVARSHA thank you for using our library 🙂 It looks like there's a small error in the docs that you uncovered by using LightlyDataset instead VOCDetection. Can you try adding the transform argument to the LightlyDataset?

dataset = LightlyDataset("/content/drive/MyDrive/Dataset", transform=transform)

Thanks, @philippmwirth! This worked perfectly—it's been 15 minutes without any issues so far. I do have a question, though: what’s the difference between uploading the dataset through the Lightly app and using LightlyDataset directly to access the image directory? I apologize if this seems like an amateur question; I’m just curious about the purpose of the app.

darthvaddor avatar Oct 18 '24 06:10 darthvaddor

Great, thanks for the update!

The Lightly app focuses on data curation and selection. You can learn more about it in the docs. The open-source library focuses on self-supervised learning. It has separate docs which can be found here. The underlying goal is to enable you to go from unlabeled data to a production ready machine learning model as efficiently as possible.

philippmwirth avatar Oct 18 '24 07:10 philippmwirth

The doc for MSN has the same issue.

OpenByteDev avatar Oct 19 '24 18:10 OpenByteDev

Thanks for reporting! We should probably sweep through all examples quickly and fix it everywhere.

philippmwirth avatar Oct 20 '24 09:10 philippmwirth

Great, thanks for the update!

The Lightly app focuses on data curation and selection. You can learn more about it in the docs. The open-source library focuses on self-supervised learning. It has separate docs which can be found here. The underlying goal is to enable you to go from unlabeled data to a production ready machine learning model as efficiently as possible.

Hey @philippmwirth ! Another question I have is why is the output dimension of vit small 2048, and not 384

criterion = DINOLoss(
    output_dim=2048,
    warmup_teacher_temp_epochs=5,
)

darthvaddor avatar Oct 21 '24 08:10 darthvaddor

The output dimension of the DINOLoss should correspond to the output dimension of the projection head and not of the backbone. By default it is 65536. For small datasets you might want to reduce the dimension.

guarin avatar Oct 21 '24 08:10 guarin