CoMAE
Add HF integration
Hi @yangjiangeyjg,
Thanks for this nice work! I see the checkpoints are currently hosted on Google Drive. This PR aims to make your models discoverable at https://huggingface.co/models?pipeline_tag=image-feature-extraction.
I wrote a quick PoC to show that you can easily integrate with the 🤗 hub, so that you can automatically load the various CoMAE models using from_pretrained (and push them using push_to_hub), track download numbers for your models (similar to models in the Transformers library), and have nice model cards on a per-model basis. It leverages the PyTorchModelHubMixin class, which lets the model class inherit these methods.
Usage is as follows:
from models_cpc import MaskedAutoencoderViT
# define model
model = MaskedAutoencoderViT(...)
# equip with weights
model.load_state_dict(...)
# push to the hub
model.push_to_hub("MCG-NJU/comae-base")
# reload
model = MaskedAutoencoderViT.from_pretrained("MCG-NJU/comae-base")
This means people don't need to manually download a checkpoint into their local environment first; it loads automatically from the hub. Checkpoints could be pushed to https://huggingface.co/MCG-NJU.
Would you be interested in this integration?
Kind regards,
Niels
Note
Please don't merge this PR before pushing a model to the hub :)