dinov2

Semantic segmentation

anshkumar opened this issue

I can't find the code for semantic segmentation. The paper says:

> A linear layer is trained to predict class logits from patch tokens. It is used to produce a low-resolution logit map (e.g. 32x32 for a model with patch size 16), which is then upsampled to full resolution (512x512) to obtain a segmentation map.

Does this mean a linear layer with 32*32 = 1024 output classes needs to be trained? And what about `n_last_blocks_list = [1, 4]` with `n_last_blocks = max(n_last_blocks_list)`? Does that need to be changed to `n_last_blocks_list = [1, 1]` with `n_last_blocks = max(n_last_blocks_list)`?

Is there any sample code for semantic segmentation?

anshkumar, Apr 19 '23

This is the relevant paragraph. I don't think 32x32 is related to the number of classes; it is the resolution of the low-resolution logit map.

[Screenshot: excerpt from the DINOv2 paper]

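The dataset seems to be ADE20K. The full ADE20K release has 3688 classes, although the standard semantic segmentation benchmark built on it (SceneParse150) uses 150.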

[Screenshot: excerpt from the DINOv2 paper]

woctezuma, Apr 19 '23

32 x 16 = 512, so starting from a 512x512-pixel crop you would end up with a tensor of shape [batch size, # of patches, # of classes], i.e. [1, 32x32, # of classes], where # of classes is the number of classes you fine-tune on.

I don't think you want to touch the intermediate layers; just train a head that learns the mapping from the output of the transformer stack to the segmentation labels.
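As a rough check of the shape arithmetic above, here is a small helper (the name `seg_shapes` is hypothetical, and 150 classes is just an example value) that follows a patch-wise linear head from the token grid to the upsampled logit map:

```python
def seg_shapes(img_size, patch_size, num_classes, batch_size=1):
    """Tensor shapes along a patch-wise linear segmentation head (illustrative only)."""
    if img_size % patch_size:
        raise ValueError("image size should be a multiple of the patch size")
    grid = img_size // patch_size    # tokens per side, e.g. 512 // 16 = 32
    n_patches = grid * grid          # e.g. 32 * 32 = 1024 patch tokens
    return {
        "head_output": (batch_size, n_patches, num_classes),    # one logit vector per patch
        "logit_map": (batch_size, num_classes, grid, grid),     # reshaped low-res logit map
        "upsampled": (batch_size, num_classes, img_size, img_size),
    }


print(seg_shapes(512, 16, 150))
# {'head_output': (1, 1024, 150), 'logit_map': (1, 150, 32, 32), 'upsampled': (1, 150, 512, 512)}
```

The number of classes only enters the last dimension of each token's output; the 32x32 comes purely from image size divided by patch size.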

ccharest93, Apr 19 '23

In the Linear class, I can do the following:

`nn.Linear(in_dim, 32 * 32)`

But where do I put the # of classes in the output dims?

anshkumar, Apr 19 '23

You would probably prefer 32*32*N_cls outputs, to predict a 32x32 logit map for N_cls classes.

See for instance how it is written for SegFormer: (H/4)*(W/4)*N_cls

[Screenshots: excerpts from the SegFormer paper]

To upsample the map and take the argmax, you can refer to the 🤗 (Hugging Face) documentation on semantic segmentation.

Take everything I write with a grain of salt though.
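A minimal sketch of the upsample-then-argmax step, assuming the head already produced a low-resolution logit map of shape (B, num_classes, h, w); the function name is hypothetical and the random tensor stands in for real logits:

```python
import torch
import torch.nn.functional as F


def logits_to_segmentation(logits, out_size):
    """(B, num_classes, h, w) low-res logits -> (B, H, W) integer class ids."""
    # Upsample the logits (not the class ids) bilinearly, then take the per-pixel argmax
    up = F.interpolate(logits, size=out_size, mode="bilinear", align_corners=False)
    return up.argmax(dim=1)


logits = torch.randn(2, 150, 32, 32)   # e.g. a 32x32 logit map for 150 classes
seg = logits_to_segmentation(logits, (512, 512))
print(seg.shape)  # torch.Size([2, 512, 512])
```

Interpolating the logits before the argmax keeps class boundaries smooth; upsampling an argmax map would only give blocky nearest-neighbour labels.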

woctezuma, Apr 19 '23

Here is the simplest semantic segmentation head I've built on top of the patch features:

```python
import torch
import torch.nn.functional as F


class LinearClassifierToken(torch.nn.Module):
    """Patch-wise linear head: a 1x1 conv over the tokenH x tokenW grid of patch tokens."""
    def __init__(self, n_tokens, in_channels, nc=1, tokenW=32, tokenH=32):
        super().__init__()
        self.in_channels = in_channels
        self.W = tokenW
        self.H = tokenH
        self.nc = nc
        self.conv = torch.nn.Conv2d(in_channels, nc, (1, 1))

    def forward(self, x):
        # (B, H*W, C) -> (B, H, W, C) -> (B, C, H, W)
        return self.conv(x.reshape(-1, self.H, self.W, self.in_channels).permute(0, 3, 1, 2))


classlayer = LinearClassifierToken(n_tokens=1024, in_channels=768, nc=1, tokenW=32, tokenH=32).cuda()
optimizer = torch.optim.Adam(classlayer.parameters())
lossfn = torch.nn.BCEWithLogitsLoss()  # binary (single-class) masks

# ViT-B/14 gives 768-dim tokens; a 448x448 input gives a 32x32 grid of 14-pixel patches
dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14').cuda().eval()
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()  # ImageNet normalization
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()

dataloader = ...  # add dataloader here: images (B, 3, H, W) in [0, 1], masks (B, 1, H, W)
for images, masks in dataloader:
    target = F.interpolate(masks.float(), (32, 32)).cuda()  # downsample masks to the token grid
    imagesnorm = (images.cuda() - mean) / std
    with torch.no_grad():  # the backbone stays frozen
        features = dinov2_vitb14.forward_features(F.interpolate(imagesnorm, (448, 448)))['x_norm_patchtokens']
    optimizer.zero_grad()
    preds = classlayer(features)
    loss = lossfn(preds, target)
    loss.backward()
    optimizer.step()
    print(loss.item())
```
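The loop above uses `BCEWithLogitsLoss`, which suits a single binary mask. For more than two classes, one common option (an assumption here, not something from the thread) is `CrossEntropyLoss` with integer masks. In this sketch, random tensors stand in for the DINOv2 patch tokens and the labels, and `num_classes = 5` is hypothetical:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_classes = 5                  # hypothetical number of classes
B, H, W, D = 2, 32, 32, 768      # 32x32 tokens from a 448x448 input to ViT-B/14 (768-dim)

head = nn.Conv2d(D, num_classes, kernel_size=1)   # patch-wise linear head
optimizer = torch.optim.Adam(head.parameters())
lossfn = nn.CrossEntropyLoss()

features = torch.randn(B, H * W, D)                    # stand-in for x_norm_patchtokens
masks = torch.randint(0, num_classes, (B, 448, 448))   # integer class id per pixel

# Nearest-neighbour downsampling keeps the labels valid integer class ids
target = F.interpolate(masks[:, None].float(), size=(H, W), mode="nearest")[:, 0].long()

optimizer.zero_grad()
logits = head(features.reshape(B, H, W, D).permute(0, 3, 1, 2))  # (B, num_classes, H, W)
loss = lossfn(logits, target)
loss.backward()
optimizer.step()
print(loss.item())
```

Downsampling the labels is the cheapest choice; an alternative is to upsample the logits to the mask resolution and compute the loss there.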

Alexankharin, Apr 21 '23

@Alexankharin But this uses a conv operation, whereas the paper specifically says a dense linear layer is used. The only way I can think of doing it is as follows:

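```python
import torch
import torch.nn.functional as F

# feature_model, linear_classifiers, images, img_h, img_w and num_classes are defined elsewhere
features = feature_model(images)
outputs = linear_classifiers(features)                 # assumed shape: (B, 32*32)
# min-max normalize each image's outputs to [0, 1]
out_min = outputs.min(dim=-1, keepdim=True)[0]
out_max = outputs.max(dim=-1, keepdim=True)[0]
outputs = (outputs - out_min) / (out_max - out_min)
outputs = outputs.reshape(-1, 1, 32, 32)               # (B, 1, 32, 32)
outputs = F.interpolate(outputs, size=(img_h, img_w), mode='bilinear', align_corners=False)
outputs = outputs.squeeze(1)
# scale [0, 1] to [0, num_classes] and truncate to an integer label
outputs = (outputs * num_classes).to(torch.int)
```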

anshkumar, Apr 21 '23

Perhaps I misunderstood the paper, but I thought it described a patch-wise linear classification over the features. If so, a 1x1 convolution on the unrolled patches is mathematically equivalent to a linear classifier over the patch features.

Alexankharin, Apr 21 '23

The fact that the layer is linear doesn't really matter; it is just a way of saying that DINOv2's frozen features are so good that you can train a simple head on top and get good results. 😄

If Alexankharin's simple code gives good results, then it is fine. Plus the explanation is probably correct.

woctezuma, Apr 21 '23

> The fact that the layer is linear does not really matter, it is just a way to say that DINOv2's frozen features are really good, so that you can train a simple head and get good results. 😄

Exactly. Any valid head should be fine. Linear is the easiest to train, but a larger one will get better results.

> Probably I understood paper wrong, but thought it was mentioned linear classification over features patch-wise.

Spot on too. The linear head is applied separately to each patch token, i.e., it is also a 1x1 convolution.
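A small numerical check of this equivalence (random tokens, arbitrary sizes D=384 and C=10): copy an `nn.Linear`'s weights into a 1x1 `nn.Conv2d` and compare the outputs on the same patch tokens.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, H, W, D, C = 2, 32, 32, 384, 10
tokens = torch.randn(B, H * W, D)              # patch tokens, e.g. x_norm_patchtokens

linear = nn.Linear(D, C)
conv = nn.Conv2d(D, C, kernel_size=1)
with torch.no_grad():                          # give the conv exactly the same parameters
    conv.weight.copy_(linear.weight[:, :, None, None])   # (C, D) -> (C, D, 1, 1)
    conv.bias.copy_(linear.bias)

out_linear = linear(tokens)                                      # (B, H*W, C)
grid = tokens.reshape(B, H, W, D).permute(0, 3, 1, 2)            # (B, D, H, W)
out_conv = conv(grid).permute(0, 2, 3, 1).reshape(B, H * W, C)   # back to (B, H*W, C)

print(torch.allclose(out_linear, out_conv, atol=1e-4))  # True
```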

TimDarcet, Apr 25 '23

Closing as answered (and keeping track in #55).

patricklabatut, May 11 '23

I can confirm that a 1x1 convolution on the unrolled patches is mathematically equivalent to a linear layer. No information from neighboring patches is used when encoding each patch, and the 1x1 kernel has no edge effects, so no padding is needed. The number of parameters, and the inputs and outputs, are exactly the same.
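The parameter-count and padding claims can be checked directly; D=768 and C=150 below are example sizes. Both layers hold D*C weights plus C biases, and a 1x1 kernel with the default `padding=0` leaves the token grid size unchanged:

```python
import torch
import torch.nn as nn

D, C = 768, 150
linear = nn.Linear(D, C)
conv = nn.Conv2d(D, C, kernel_size=1)   # padding defaults to 0


def n_params(module):
    return sum(p.numel() for p in module.parameters())


print(n_params(linear), n_params(conv), D * C + C)  # 115350 115350 115350

x = torch.randn(1, D, 32, 32)           # a 32x32 grid of D-dim tokens
print(conv(x).shape)                    # torch.Size([1, 150, 32, 32])
```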

pranavraja99, May 25 '23

I borrowed from Alexankharin's code and the U-Net concept to build a decoder:

```python
import torch
import torch.nn as nn


class LinearClassifierToken(nn.Module):
    """Patch-wise linear layer (1x1 conv) over the tokenH x tokenW grid of patch tokens."""
    def __init__(self, in_channels, num_chanel=2, tokenW=32, tokenH=32):
        super().__init__()
        self.in_channels = in_channels
        self.W = tokenW
        self.H = tokenH
        self.nc = num_chanel
        self.conv = nn.Conv2d(in_channels, num_chanel, (1, 1))

    def forward(self, x):
        # (B, H*W, C) -> (B, C, H, W) -> 1x1 conv
        return self.conv(x.reshape(-1, self.H, self.W, self.in_channels).permute(0, 3, 1, 2))


def make_decoder(n, first_kernel):
    # The first conv (padding=1) shrinks the token grid by first_kernel - 3
    # (32 -> 28 for k=7, 16 -> 14 for k=5); four 2x upsamplings then give 448 or 224.
    layers = [nn.Conv2d(n, n, kernel_size=first_kernel, stride=1, padding=1, bias=False)]
    c = n
    for i in range(4):
        layers += [nn.Upsample(scale_factor=2),
                   nn.Conv2d(c, c // 2, kernel_size=3, stride=1, padding=1, bias=False)]
        if i < 3:  # no BatchNorm after the last conv
            layers.append(nn.BatchNorm2d(c // 2))
        layers.append(nn.ReLU(inplace=True))
        c //= 2
    return nn.Sequential(*layers)


class DinoV2(nn.Module):
    def __init__(self, num_class=1) -> None:
        super().__init__()
        self.dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
        for param in self.dinov2.parameters():
            param.requires_grad = False  # frozen backbone
        n = 512
        # ViT-S/14 patch tokens are 384-dim: 32x32 tokens for 448x448 inputs, 16x16 for 224x224
        self.classlayer_448 = LinearClassifierToken(in_channels=384, num_chanel=n, tokenW=32, tokenH=32)
        self.classlayer_224 = LinearClassifierToken(in_channels=384, num_chanel=n, tokenW=16, tokenH=16)
        self.selu = nn.SELU()
        self.to_448 = make_decoder(n, first_kernel=7)
        self.to_224 = make_decoder(n, first_kernel=5)
        self.conv2class = nn.Conv2d(n // 16, num_class, kernel_size=3, stride=1, padding=1, bias=True)

    def forward(self, x):
        # Only the 224 branch is used here, so x should be (B, 3, 224, 224) on the model's device
        with torch.no_grad():
            features = self.dinov2.forward_features(x)['x_norm_patchtokens']
        x = self.selu(self.classlayer_224(features))
        x = self.to_224(x)
        return self.conv2class(x)  # (B, num_class, 224, 224)
```

arkadaz, May 29 '23

(Quoting @Alexankharin's linear-head example above.)

Hi, does this mean ground-truth segmentation labels are needed for the images? If so, is it possible to perform unsupervised semantic segmentation with DINOv2? Many thanks

YScheung, Jun 10 '23

(Quoting @YScheung's question above.)

Yes, you need the labels. For unsupervised segmentation I recommend SAM (Segment Anything).

arkadaz, Jun 10 '23

> In the Linear class, I can do the following:
>
> nn.Linear(in_dim, 32x32)
>
> Where, to get that # of classes in the output dims?

You can use a conv layer instead. Use the number of classes as the number of out channels. Cheers!

Here is some sample code:

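```python
import torch
import torch.nn as nn


class SegmentationModel(nn.Module):
    def __init__(self, mask_dim=64, num_classes=3):
        super().__init__()
        # mask_dim = side of the token grid (input size // 14), e.g. 64 for an 896x896 input
        self.mask_dim = mask_dim
        self.num_classes = num_classes
        # Load the DINOv2 ViT-S/14 backbone (384-dim patch tokens)
        self.dino = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')

        # Freeze DINOv2 layers
        for param in self.dino.parameters():
            param.requires_grad = False

        # 1x1 conv = patch-wise linear layer with num_classes output channels
        self.segmentation_conv = nn.Sequential(
            nn.Conv2d(384, self.num_classes, kernel_size=1),
        )

    def forward(self, x):
        batch_size = x.shape[0]
        with torch.no_grad():
            x = self.dino.forward_features(x)
            x = x['x_norm_patchtokens']          # (B, mask_dim*mask_dim, 384)
            x = x.permute(0, 2, 1)               # (B, 384, mask_dim*mask_dim)
            x = x.reshape(batch_size, 384, self.mask_dim, self.mask_dim)
        x = self.segmentation_conv(x)            # (B, num_classes, mask_dim, mask_dim)
        x = x.reshape(batch_size, self.num_classes, self.mask_dim * self.mask_dim)
        return torch.sigmoid(x)                  # per-class probability for each patch

# Move the whole model (backbone and head) and the inputs to the same device, e.g.
# model = SegmentationModel().cuda(); probs = model(images.cuda())
```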

itsprakhar, Jun 14 '23

Thanks all for the guidance, and thanks to the DINOv2 team for releasing the pre-trained models.

I managed to train a semantic segmentation model in my domain, and it achieved exceptional performance, quite close to human level. The surprise is that the training set was just a dozen labeled masks.

Do we have an explanation for this high few-shot performance? Just curious!

DuongTSon, Jul 14 '23

That is a good question! If you let me speculate for a bit:

- Semantic segmentation could be a combination of two tasks: depth approximation and object classification. When doing semantic segmentation, depth approximation could provide good masks (good object-boundary estimates), and object classification could provide a means to differentiate between the masks produced.

Now if we think of DINOv2, its pretext task (a combination of DINO and iBOT) forces it to produce the same object embedding when given two overlapping crops of an image. From this pretext task, object classification could be learned through the loss (the centering/sharpening that prevents mode collapse, see the DINO paper), which forces different objects to get different classifications. The depth approximation might come from the need to pick a common focus within an image: given slightly different crops, the model has to come up with a policy for selecting which object is in focus in order to minimize the loss (for example, always taking the object in the foreground), and learning this type of policy would lead to depth approximation.

Again, this is all speculation; after all, models are often black boxes with emergent properties. Still, I think it is interesting to discuss why we believe these properties emerge, to guide further design. Let me know what you think!

ccharest93, Jul 14 '23

Hi folks,

Inspired by this thread, I created a tutorial for people regarding training a linear classifier on top of a frozen DINOv2 for semantic segmentation: https://github.com/NielsRogge/Transformers-Tutorials/blob/master/DINOv2/Train_a_linear_classifier_on_top_of_DINOv2_for_semantic_segmentation.ipynb.

DINOv2 is now available in HF Transformers as well :) https://huggingface.co/docs/transformers/main/model_doc/dinov2

NielsRogge, Jul 31 '23

(Quoting @DuongTSon's comment above.)

Hi @DuongTSon, I want to use DINOv2 in my domain, but the performance is very low. Maybe it's a fault in my code.

Could you please share your codes?

PeterKim1, Oct 12 '23

@PeterKim1 Hi, I cannot share the code since it was a project at my company. However, you can take a look at this repository https://github.com/itsprakhar/Downstream-Dinov2, which covers the basic structure of DINO-based models. Here are some lessons I learned using DINOv2:

  • Segmentation task: DINOv2 performs well on big-object segmentation but poorly on small objects
  • Classification task: the output embeddings do not work; what works is the patch embeddings + a CNN-based classification model

Hope it can help you!
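One possible reading of "patch embeddings + a CNN-based classification model", sketched under assumptions (this is not DuongTSon's code): reshape the frozen patch tokens back into a grid, run a small CNN over it, pool, and classify. The class name, hidden size and class count are hypothetical; the 16x16 grid of 384-dim tokens matches a 224x224 input to ViT-S/14.

```python
import torch
import torch.nn as nn


class PatchCNNClassifier(nn.Module):
    """Hypothetical image classifier on frozen patch tokens: grid -> small CNN -> pool -> linear."""

    def __init__(self, embed_dim=384, grid=16, num_classes=10, hidden=128):
        super().__init__()
        self.embed_dim, self.grid = embed_dim, grid
        self.cnn = nn.Sequential(
            nn.Conv2d(embed_dim, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),   # global average pool over the token grid
        )
        self.fc = nn.Linear(hidden, num_classes)

    def forward(self, patch_tokens):
        # patch_tokens: (B, grid*grid, embed_dim), e.g. x_norm_patchtokens
        b = patch_tokens.shape[0]
        x = patch_tokens.reshape(b, self.grid, self.grid, self.embed_dim).permute(0, 3, 1, 2)
        return self.fc(self.cnn(x).flatten(1))


model = PatchCNNClassifier()
logits = model(torch.randn(4, 16 * 16, 384))   # random stand-in for DINOv2 patch tokens
print(logits.shape)  # torch.Size([4, 10])
```

The 3x3 convolutions let the classifier combine neighbouring patches, which a single pooled embedding cannot do.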

DuongTSon, Oct 13 '23

(Quoting @NielsRogge's tutorial comment above.)

Thank you so much for this tutorial! I'm having a ton of fun using it to experiment in the medical domain. I'm a tinkerer, not an actual programmer, so apologies if this is a bad question...

Have you considered replicating this tutorial with the new models that include registers? I can't find them in HF, and I haven't yet managed to get it working properly by loading the model from Torch Hub. They seem quite promising for semantic segmentation. Thanks!

mzschwartz88, Dec 30 '23

Hi, they haven't been added yet to HF: https://github.com/huggingface/transformers/issues/27379.

However, this should be really easy given the tiny differences in https://github.com/facebookresearch/dinov2/pull/282/files.

NielsRogge, Jan 02 '24

(Quoting @Alexankharin's and @anshkumar's comments above.)

You could directly apply a linear layer to a tensor of shape (B, HW, D) instead of reshaping to (B, D, H, W) and using the 1x1 conv "trick" on it. PyTorch's linear layer accepts tensors with more than 2 dimensions, provided the last one equals `in_features`; it then applies the layer independently to each of the B×HW token vectors.

For clarification:

  • B = batch dimension
  • D = embedding dimension (e.g. 1024 for the large DINOv2 model)
  • H = feature map height (image height // patch size, e.g. 32 for a model with patch size 16 and a 512x512 image)
  • W = feature map width (image width // patch size)

Hence H*W is the total number of tokens.
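A short sketch of this, using random tokens and example sizes (D=1024 as for the large model, 150 classes): `nn.Linear` runs on the (B, HW, D) tensor directly, and the (B, C, H, W) layout is only needed afterwards, for upsampling or a per-pixel loss.

```python
import torch
import torch.nn as nn

B, H, W, D, C = 2, 32, 32, 1024, 150
tokens = torch.randn(B, H * W, D)       # stand-in for the patch tokens

head = nn.Linear(D, C)                  # acts on the last dim of every (batch, token) pair
logits = head(tokens)                   # (B, H*W, C): no reshape needed for the head itself

# Token i = h * W + w, so transposing and reshaping gives the (B, C, H, W) logit map
logit_map = logits.transpose(1, 2).reshape(B, C, H, W)
print(logits.shape, logit_map.shape)    # torch.Size([2, 1024, 150]) torch.Size([2, 150, 32, 32])
```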

tcourat, Feb 02 '24

(Quoting @NielsRogge's tutorial comment and @mzschwartz88's question above.)

I created a sample notebook here that uses torch hub rather than hugging face for creating a custom semantic segmentation model. As a result, you can use the models with registers.
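A minimal sketch of the Torch Hub route, assuming the register variants are exposed under `_reg` entry points such as `dinov2_vits14_reg` and that, as far as I can tell, `x_norm_patchtokens` already excludes the CLS and register tokens. The helper names are hypothetical, and loading needs network access, so the usage lines are left commented:

```python
import torch


def load_backbone(name="dinov2_vits14_reg"):
    # Downloads the model from the facebookresearch/dinov2 hub repo (needs network access)
    model = torch.hub.load("facebookresearch/dinov2", name)
    for p in model.parameters():
        p.requires_grad = False   # frozen backbone
    return model.eval()


def patch_tokens_to_grid(patch_tokens, grid_h, grid_w):
    """(B, grid_h*grid_w, D) -> (B, D, grid_h, grid_w), ready for a conv head or upsampling."""
    b, n, d = patch_tokens.shape
    if n != grid_h * grid_w:
        raise ValueError("token count does not match the grid")
    return patch_tokens.transpose(1, 2).reshape(b, d, grid_h, grid_w)


# Usage, assuming 224x224 inputs (a 16x16 grid of 14-pixel patches):
# backbone = load_backbone()
# with torch.no_grad():
#     out = backbone.forward_features(images)
# grid = patch_tokens_to_grid(out["x_norm_patchtokens"], 16, 16)
```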

antmedellin, Jun 27 '24

Awesome, will check it out. Thanks!!

mzschwartz88, Jun 27 '24