dinov2 icon indicating copy to clipboard operation
dinov2 copied to clipboard

Is this the right way to fine-tune DINOv2?

Open namrahrehman opened this issue 8 months ago • 16 comments

I am trying to fine-tune DINOv2 for image classification on a custom dataset (a medical image dataset) with the objective of increasing accuracy. The problem is that with linear evaluation I get an adequate accuracy of almost 75%, but when I fine-tune the whole backbone I can never get an accuracy higher than 40%. Is there something semantically wrong with how I am trying to fine-tune this model? I even tried it with CIFAR-10 and got excellent performance with linear evaluation but poor performance with fine-tuning. Also, when I used the model from the hub and ran the following code snippet, I got "Pre-trained DINO weights are not found in the model's state_dict.", so instead I had to load the model from Hugging Face to fine-tune the whole backbone:

# Name-based check: collect every state_dict key containing the substring 'dino'.
# NOTE(review): this only inspects parameter *names*; an empty result does not
# by itself show that the loaded weights are not pretrained.
pretrained_dino_keys = [k for k in model.state_dict() if 'dino' in k]

if pretrained_dino_keys:
    print("Pre-trained DINO weights are present in the model's state_dict.")
else:
    print("Pre-trained DINO weights are not found in the model's state_dict.")

the following is my code for fine-tuning:

from transformers import Dinov2ForImageClassification
# Load the pretrained checkpoint (backbone plus its existing classification head).
model = Dinov2ForImageClassification.from_pretrained("facebook/dinov2-small-imagenet1k-1-layer")

# Full fine-tuning: mark every backbone parameter as trainable.
for param in model.dinov2.parameters():
    param.requires_grad = True
for param in model.classifier.parameters():
    param.requires_grad = True

# Customize the head for the classification task
num_classes = 10  # Number of classes in the dataset
# Replace the head with a fresh linear layer for classification and move it to
# the GPU. The input width is read from the existing head (the original code
# hard-coded 768 for this checkpoint) so the snippet also works for other sizes.
model.classifier = nn.Linear(model.classifier.in_features, num_classes).to(device)

# Define the loss function
loss_fn = nn.CrossEntropyLoss()

weight_decay = 1e-3
lr = 0.001
step_size = 5
# The optimizer is created after the head is replaced, so model.parameters()
# includes the new classifier's weights.
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

# Create a learning rate scheduler
# NOTE(review): StepLR multiplies the learning rate by gamma every `step_size`
# scheduler steps; gamma=0.0001 shrinks it by four orders of magnitude at once.
scheduler = StepLR(optimizer, step_size=step_size, gamma=0.0001)
def make_classification_eval_transform(
    resize_size: int = 256,
    crop_size: int = 224,
    interpolation=None,
) -> transforms.Compose:
    """Build the deterministic preprocessing pipeline applied to every image.

    Args:
        resize_size: Target size passed to ``transforms.Resize``.
        crop_size: Side length of the center crop taken after resizing, so
            that every image in a batch has the same spatial size.
        interpolation: Interpolation mode for the resize. ``None`` selects
            ``transforms.InterpolationMode.BICUBIC``.

    Returns:
        A ``transforms.Compose`` that resizes, center-crops, converts to a
        tensor and normalizes each channel with mean 0.5 and std 0.5.
    """
    # Resolved here rather than in the signature so that defining the function
    # does not require `transforms` to be evaluated for a default value.
    if interpolation is None:
        interpolation = transforms.InterpolationMode.BICUBIC

    transforms_list = [
        transforms.Resize(resize_size, interpolation=interpolation),
        transforms.CenterCrop(crop_size),
        # Normalize operates on tensors, so convert before normalizing.
        transforms.ToTensor(),
        # NOTE(review): these statistics are kept from the original snippet;
        # confirm they match what the pretrained checkpoint expects.
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(transforms_list)

# Build the preprocessing pipeline once and share it across all three splits.
transform = make_classification_eval_transform()

# One ImageFolder dataset per split, in train / validation / test order.
train_dataset, valid_dataset, test_dataset = (
    ImageFolder(root=split_root, transform=transform)
    for split_root in (train_dataset_path, valid_dataset_path, test_dataset_path)
)

# Loader settings common to every split; only the training loader shuffles.
# Batches are moved to the device later, inside the training loop.
loader_options = {"batch_size": 32, "num_workers": 2}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

# Place the model on the target device.
model = model.to(device)
# Set random seed

# Define the number of epochs
num_epochs = 20

# Per-epoch history of the sample-averaged loss and the accuracy for each split
loss_hist_train = [0.0] * num_epochs
accuracy_hist_train = [0.0] * num_epochs
loss_hist_valid = [0.0] * num_epochs
accuracy_hist_valid = [0.0] * num_epochs

for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()  # make sure the model is in training mode
    loss_accumulated_train = 0.0
    total_samples_train = 0
    correct_predictions_train = 0

    for x_batch, y_batch in train_loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        # Clear gradients left over from the previous step before the forward pass.
        optimizer.zero_grad()
        output = model(x_batch)
        logits = output.logits
        loss = loss_fn(logits, y_batch)

        # Backpropagate and update the weights; without these two calls the
        # model parameters never change and nothing is actually fine-tuned.
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by the batch size so the epoch value is a
        # per-sample average even when the last batch is smaller.
        loss_accumulated_train += loss.item() * y_batch.size(0)
        total_samples_train += y_batch.size(0)

        # Calculate accuracy
        predicted = torch.max(logits, 1)[1]
        correct_predictions_train += torch.sum(predicted == y_batch).item()

    # Advance the learning-rate schedule once per epoch, after the optimizer steps.
    scheduler.step()

    loss_hist_train[epoch] = loss_accumulated_train / total_samples_train  # average loss per sample
    accuracy_hist_train[epoch] = correct_predictions_train / total_samples_train

    # ---- Validation pass ----
    model.eval()  # switch to evaluation mode for the validation pass
    with torch.no_grad():  # no gradients are needed for evaluation
        loss_accumulated_valid = 0.0
        total_samples_valid = 0
        correct_predictions_valid = 0

        for x_batch, y_batch in valid_loader:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            output = model(x_batch)
            logits = output.logits
            loss = loss_fn(logits, y_batch)
            loss_accumulated_valid += loss.item() * y_batch.size(0)
            total_samples_valid += y_batch.size(0)

            # Calculate accuracy
            predicted = torch.max(logits, 1)[1]
            correct_predictions_valid += torch.sum(predicted == y_batch).item()

        loss_hist_valid[epoch] = loss_accumulated_valid / total_samples_valid  # average loss per sample
        accuracy_hist_valid[epoch] = correct_predictions_valid / total_samples_valid

    print(f'Epoch {epoch + 1} accuracy: {accuracy_hist_train[epoch]:.4f} val_accuracy: {accuracy_hist_valid[epoch]:.4f} loss: {loss_hist_train[epoch]:.4f} val_loss: {loss_hist_valid[epoch]:.4f}')

namrahrehman avatar Oct 23 '23 14:10 namrahrehman