
DPTForDepthEstimation with Dinov2 does not use pretrained weights

ducha-aiki opened this issue 1 year ago • 2 comments

System Info

  • transformers version: 4.39.3
  • Platform: Linux-4.19.0-26-cloud-amd64-x86_64-with-glibc2.17
  • Python version: 3.8.12
  • Huggingface_hub version: 0.21.3
  • Safetensors version: 0.4.2
  • Accelerate version: 0.28.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help?

@amyeroberts @stevhliu

Information

  • [X] The official example scripts

Tasks

  • [X] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction

The code provided in the https://huggingface.co/docs/transformers/en/model_doc/dpt#usage-tips example (likely merged from https://github.com/huggingface/transformers/issues/26057) does not in fact use pretrained DINOv2 weights; instead it uses a randomly initialized backbone.

To reproduce:

import matplotlib.pyplot as plt
from transformers import Dinov2Config, DPTConfig, DPTForDepthEstimation

backbone_config = Dinov2Config.from_pretrained("facebook/dinov2-base", out_features=["stage1", "stage2", "stage3", "stage4"])
config = DPTConfig(backbone_config=backbone_config)

model = DPTForDepthEstimation(config=config)

plt.imshow(model.backbone.embeddings.patch_embeddings.projection.weight[0, 0].detach().cpu().numpy())

Let's visualize the actual DINOv2 filters for comparison:

import matplotlib.pyplot as plt
from transformers import AutoImageProcessor, AutoModel

model = AutoModel.from_pretrained('facebook/dinov2-base')

plt.imshow(model.embeddings.patch_embeddings.projection.weight[0, 0].detach().cpu().numpy())

Expected behavior

I would expect that DPTForDepthEstimation would take pretrained DINOv2 weights, not random ones.

Here are the visualizations for DPTForDepthEstimation DINOv2 and real DINOv2:

DPT DINOv2:

[image: patch-embedding projection weights of the DPT DINOv2 backbone]

Real DINOv2:

[image: patch-embedding projection weights of facebook/dinov2-base]

Or, at the very least, I would expect the documentation to state that no pretrained DINOv2 weights are involved here.

ducha-aiki avatar Apr 05 '24 12:04 ducha-aiki

The workaround I am using now is to start from random weights, then load one of Facebook's DPT-for-depth checkpoints and initialize the backbone from it:

from transformers import DPTForDepthEstimation

# copy the pretrained backbone into the randomly initialized model (`my_model` from above)
model_nyu = DPTForDepthEstimation.from_pretrained("facebook/dpt-dinov2-base-nyu")
my_model.backbone.load_state_dict(model_nyu.backbone.state_dict())
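For what it's worth, this backbone copy relies only on standard PyTorch state_dict semantics; a minimal self-contained sketch with toy modules (hypothetical stand-ins, not the real DPT classes):

```python
import torch
import torch.nn as nn

# Toy stand-ins for the two backbones: identical architecture, independent random init
pretrained = nn.Linear(4, 4)   # plays the role of model_nyu.backbone
fresh = nn.Linear(4, 4)        # plays the role of my_model.backbone

# The two random initializations differ...
assert not torch.allclose(pretrained.weight, fresh.weight)

# ...until load_state_dict copies every parameter tensor across
fresh.load_state_dict(pretrained.state_dict())
assert torch.allclose(pretrained.weight, fresh.weight)
```

This only works because the two backbones share the exact same architecture, so their state_dict keys and tensor shapes line up.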

ducha-aiki avatar Apr 06 '24 08:04 ducha-aiki

Yes, that's right; we could clarify this in the docs. However, @amyeroberts is working on support for passing use_pretrained_backbone=True: the plan is that people can pass this additional flag to the config in order to instantiate the backbone with pretrained weights. This is not supported as of now.

Would you be willing to open a PR to clarify this? The docs are here.
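Once that lands, I would presume the usage looks something like the following (a sketch of the planned API, not something that works today):

```
backbone_config = Dinov2Config.from_pretrained("facebook/dinov2-base", out_features=["stage1", "stage2", "stage3", "stage4"])
config = DPTConfig(backbone_config=backbone_config, use_pretrained_backbone=True)
model = DPTForDepthEstimation(config=config)  # backbone would come up pretrained
```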

NielsRogge avatar Apr 06 '24 09:04 NielsRogge

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Jun 01 '24 08:06 github-actions[bot]

This is being worked on at #31145

NielsRogge avatar Jun 01 '24 08:06 NielsRogge

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Jun 26 '24 08:06 github-actions[bot]