dinov2
dinov2 copied to clipboard
Is this the right way to fine-tune DINOv2?
I am trying to finetune dinov2 for image classification on a custom dataset (medical image dataset) with the objective of increasing accuracy. The problem is that when I use linear evaluation I get an adequate accuracy of almost 75%, however when I try to finetune(the whole backbone) I can never get an accuracy higher than 40%, is there something semantically wrong with how I am trying to finetune this model? I even tried it with cifar10 and got an excellent performance on linear evaluation but a poor performance on fine-tuning. Also when I used the model from the hub and ran the following code snippet, I got "Pre-trained DINO weights are not found in the model's state_dict." so instead I had to load the model from hugging face for fine-tuning the whole backbone :
pretrained_dino_keys = [k for k in model.state_dict() if 'dino' in k]
if pretrained_dino_keys:
print("Pre-trained DINO weights are present in the model's state_dict.")
else:
print("Pre-trained DINO weights are not found in the model's state_dict.")
the following is my code for fine-tuning:
from transformers import Dinov2ForImageClassification
model = Dinov2ForImageClassification.from_pretrained("facebook/dinov2-small-imagenet1k-1-layer")
for param in model.dinov2.parameters():
param.requires_grad = True
for param in model.classifier.parameters():
param.requires_grad = True
# Customize the head for the classification task
num_classes = 10 # Number of classes in the dataset
model.classifier = nn.Linear(768, num_classes).to(device) a linear layer for classification and move to GPU
# Define the loss function
loss_fn = nn.CrossEntropyLoss()
weight_decay = 1e-3
lr = 0.001
step_size = 5
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
# Create a learning rate scheduler
scheduler = StepLR(optimizer, step_size=step_size, gamma=0.0001)
def make_classification_eval_transform(
*,
resize_size: int = 256,
interpolation=transforms.InterpolationMode.BICUBIC,
crop_size: int = 224,
) -> transforms.Compose:
transforms_list = [
transforms.Resize(resize_size, interpolation=interpolation),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
return transforms.Compose(transforms_list)
# Use the make_classification_eval_transform function to create the transformation pipeline
transform = make_classification_eval_transform()
# Set up data loaders for training, validation, and test
train_dataset = ImageFolder(root=train_dataset_path, transform=transform)
valid_dataset = ImageFolder(root=valid_dataset_path, transform=transform)
test_dataset = ImageFolder(root=test_dataset_path, transform=transform)
# Modify data loading to move data to the same device as the model
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)
model = model.to(device)
# Set random seed
torch.manual_seed(1)
# Define the number of epochs
num_epochs = 20
# Initialize lists to store loss and accuracy for each epoch
loss_hist_train = [0.0] * num_epochs
accuracy_hist_train = [0.0] * num_epochs
loss_hist_valid = [0.0] * num_epochs
accuracy_hist_valid = [0.0] * num_epochs
for epoch in range(num_epochs):
model.train()
loss_accumulated_train = 0.0 # Initialize to zero
total_samples_train = 0 # Initialize to zero
correct_predictions_train = 0 # Initialize to zero
for x_batch, y_batch in train_loader:
x_batch = x_batch.to(device)
y_batch = y_batch.to(device)
output = model(x_batch)
logits = output.logits
loss = loss_fn(logits, y_batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
loss_accumulated_train += loss.item() * y_batch.size(0) # Accumulate as a scalar
total_samples_train += y_batch.size(0)
# Calculate accuracy
predicted = torch.max(logits, 1)[1]
correct_predictions_train += torch.sum(predicted == y_batch).item() # Accumulate as a scalar
loss_hist_train[epoch] = loss_accumulated_train / total_samples_train # Calculate average loss per batch
accuracy_hist_train[epoch] = correct_predictions_train / total_samples_train # Calculate accuracy directly
scheduler.step()
model.eval()
with torch.no_grad():
loss_accumulated_valid = 0.0 # Initialize to zero
total_samples_valid = 0 # Initialize to zero
correct_predictions_valid = 0 # Initialize to zero
for x_batch, y_batch in valid_loader:
x_batch = x_batch.to(device)
y_batch = y_batch.to(device)
output = model(x_batch)
logits = output.logits
loss = loss_fn(logits, y_batch)
loss_accumulated_valid += loss.item() * y_batch.size(0) # Accumulate as a scalar
total_samples_valid += y_batch.size(0)
# Calculate accuracy
predicted = torch.max(logits, 1)[1]
correct_predictions_valid += torch.sum(predicted == y_batch).item() # Accumulate as a scalar
loss_hist_valid[epoch] = loss_accumulated_valid / total_samples_valid # Calculate average loss per batch
accuracy_hist_valid[epoch] = correct_predictions_valid / total_samples_valid # Calculate accuracy directly
print(f'Epoch {epoch + 1} accuracy: {accuracy_hist_train[epoch]:.4f} val_accuracy: {accuracy_hist_valid[epoch]:.4f} loss: {loss_hist_train[epoch]:.4f} val_loss: {loss_hist_valid[epoch]:.4f}')