
Add HF integration

Open NielsRogge opened this issue 5 months ago • 0 comments

Hi @yangjiangeyjg,

Thanks for this nice work! I see the checkpoints are currently on Google Drive. This PR aims to make your models discoverable from https://huggingface.co/models?pipeline_tag=image-feature-extraction.

I wrote a quick PoC to showcase that you can easily integrate with the 🤗 hub, so that you can automatically load the various CoMAE models using from_pretrained (and push them using push_to_hub), track download numbers for your models (similar to models in the Transformers library), and have nice model cards on a per-model basis. It leverages the PyTorchModelHubMixin class, which lets the model class inherit these methods.
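To illustrate the mixin idea without any dependencies, here is a minimal sketch of how inheriting from a mixin can give a model class save/load methods. Everything in it (ToyHubMixin, ToyModel, roundtrip, the JSON file layout) is hypothetical and only stands in for the pattern. It is not how PyTorchModelHubMixin is implemented, and it writes to a local temporary directory rather than the hub.

```python
import json
import tempfile
from pathlib import Path


class ToyHubMixin:
    """Hypothetical stand-in for a hub mixin; NOT the real PyTorchModelHubMixin."""

    def save_pretrained(self, save_directory):
        # Persist the constructor arguments and the weights as JSON files.
        path = Path(save_directory)
        path.mkdir(parents=True, exist_ok=True)
        (path / "config.json").write_text(json.dumps(self.config))
        (path / "weights.json").write_text(json.dumps(self.state_dict()))

    @classmethod
    def from_pretrained(cls, directory):
        # Rebuild the model from the saved config, then restore its weights.
        path = Path(directory)
        config = json.loads((path / "config.json").read_text())
        model = cls(**config)
        model.load_state_dict(json.loads((path / "weights.json").read_text()))
        return model


class ToyModel(ToyHubMixin):
    """Toy model: it gains save_pretrained/from_pretrained purely by inheritance."""

    def __init__(self, embed_dim=4):
        self.config = {"embed_dim": embed_dim}
        self.weights = [0.0] * embed_dim

    def state_dict(self):
        return {"weights": list(self.weights)}

    def load_state_dict(self, state):
        self.weights = list(state["weights"])


def roundtrip(embed_dim=3):
    # Save a model to a temporary directory and reload it through the mixin.
    model = ToyModel(embed_dim=embed_dim)
    model.weights = [float(i) for i in range(embed_dim)]
    with tempfile.TemporaryDirectory() as tmp:
        model.save_pretrained(tmp)
        return ToyModel.from_pretrained(tmp)
```

The point of the pattern is that the model class itself only needs to expose its config and state. The save and load logic lives in the mixin, so the same methods can be reused across model classes.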

Usage is as follows:

from models_cpc import MaskedAutoencoderViT

# define model
model = MaskedAutoencoderViT(...)

# equip with weights
model.load_state_dict(...)

# push to the hub
model.push_to_hub("MCG-NJU/comae-base")

# reload
model = MaskedAutoencoderViT.from_pretrained("MCG-NJU/comae-base")

This means people don't need to manually download a checkpoint to their local environment first; it loads automatically from the hub. Checkpoints could be pushed to https://huggingface.co/MCG-NJU.

Would you be interested in this integration?

Kind regards,

Niels

Note

Please don't merge this PR before pushing a model to the hub :)

NielsRogge, Sep 01 '24 09:09