AdvancedLiterateMachinery

Where do I load pre-trained DiT and GiT in the VGT code?

Open • YuNie24 opened this issue 1 year ago • 5 comments

I downloaded two weight files for fine-tuning VGT.

When fine-tuning VGT, where should I specify the paths to the pre-trained ViT and DiT weight files? If possible, please point out which file loads each of the two weight paths.

YuNie24 • Jan 04 '24 05:01

I have the same issue: I can't find exactly where the GiT and ViT models are loaded. Could you please help me with this? @yashsandansing @alibaba-oss @Wangsherpa

bavo96 • Apr 12 '24 05:04

There are command-line arguments, including --opts, where you can pass additional config options as key-value pairs, like this:
MODEL.WEIGHTS <path to model weights>
Alternatively, in the config file for your dataset (D4LA, DocBank, DocLayNet or PubLayNet) in the Configs/cascade directory, you can set the model weights directly under MODEL:
WEIGHTS: "<yourdirectorystructure>/<weightfilename>"
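To make the key-value override idea concrete, here is a minimal standard-library sketch, assuming the overrides behave like detectron2/yacs KEY VALUE pairs. merge_from_list below is a hypothetical, simplified stand-in for yacs' CfgNode.merge_from_list (it keeps values as strings and skips yacs' type checking); it is not VGT's code:

```python
from typing import Any, Dict, List


def merge_from_list(cfg: Dict[str, Any], opts: List[str]) -> Dict[str, Any]:
    """Apply flat KEY VALUE overrides, e.g. ["MODEL.WEIGHTS", "/w.pth"], to a nested dict.

    Hypothetical, simplified stand-in for yacs' CfgNode.merge_from_list:
    values are stored as plain strings and no type checking is done.
    """
    if len(opts) % 2 != 0:
        raise ValueError(f"Overrides must come in KEY VALUE pairs, got {opts}")
    for key, value in zip(opts[0::2], opts[1::2]):
        *parents, leaf = key.split(".")
        node = cfg
        # Walk down the dotted path, rejecting sections that do not exist
        for name in parents:
            child = node.get(name)
            if not isinstance(child, dict):
                raise KeyError(f"Unknown config section in {key!r}")
            node = child
        # Only existing keys may be overridden, which catches typos like MODEL.WEIGHT
        if leaf not in node:
            raise KeyError(f"Unknown config key: {key!r}")
        node[leaf] = value
    return cfg


if __name__ == "__main__":
    # Defaults as they might come from a YAML file (hypothetical values)
    cfg = {"MODEL": {"WEIGHTS": ""}, "OUTPUT_DIR": "./output"}
    # Equivalent of appending "MODEL.WEIGHTS /weights/vgt.pth" to the command line
    merge_from_list(cfg, ["MODEL.WEIGHTS", "/weights/vgt.pth"])
    print(cfg["MODEL"]["WEIGHTS"])  # prints /weights/vgt.pth
```

Because only existing keys can be overridden, a misspelled key fails loudly instead of silently leaving the old weight path in place.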

NOTE: you can manually download the DiT (base or large) weights from the DiT repository, and LayoutLM's pytorch_model.bin from Hugging Face.

Harsh19012003 • Jun 16 '24 06:06

Hi @bavo96, were you able to figure this out?

ritutweets46 • Oct 31 '24 09:10

Hi @ritutweets46, based on my current understanding, the path to the pre-trained ViT model is hard-coded in the MyDetectionCheckpointer class in VGTcheckpointer.py:

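from typing import Any

import torch
from detectron2.checkpoint import DetectionCheckpointer
from fvcore.common.checkpoint import _IncompatibleKeys

class MyDetectionCheckpointer(DetectionCheckpointer):
    def _load_model(self, checkpoint: Any) -> _IncompatibleKeys:
        ...
        # Hard-coded path to the pre-trained DiT weights: change it to your downloaded file
        DiT_checkpoint_state_dict = torch.load("/path/dit-base-224-p16-500k-62d53a.pth", map_location=torch.device("cpu"))["model"]
        ...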

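If you would rather not edit that hard-coded string for every run, one option is to resolve the path at load time. This is a hypothetical sketch: the environment variable VGT_DIT_WEIGHTS and the helper resolve_dit_path are made-up names that VGT does not define, and the fallback is simply the placeholder path from the snippet above:

```python
import os
from typing import Mapping, Optional

# Hypothetical variable name: VGT does not read it unless you add this code yourself.
DIT_ENV_VAR = "VGT_DIT_WEIGHTS"
# Fallback: the placeholder path from the hard-coded torch.load call above.
DEFAULT_DIT_PATH = "/path/dit-base-224-p16-500k-62d53a.pth"


def resolve_dit_path(env: Optional[Mapping[str, str]] = None) -> str:
    """Return the DiT checkpoint path, preferring a non-empty environment override."""
    source = os.environ if env is None else env
    override = source.get(DIT_ENV_VAR, "").strip()
    return override or DEFAULT_DIT_PATH


if __name__ == "__main__":
    print(resolve_dit_path({}))  # prints /path/dit-base-224-p16-500k-62d53a.pth
    print(resolve_dit_path({DIT_ENV_VAR: "/weights/dit-base.pth"}))  # prints /weights/dit-base.pth
```

Inside _load_model, you would then pass resolve_dit_path() to torch.load in place of the literal path.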
The path to the pre-trained GiT model comes from cfg.MODEL.WEIGHTS, which the VGTTrainer class loads during training:

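from detectron2.engine import TrainerBase

class VGTTrainer(TrainerBase):
    ...
    def resume_or_load(self, resume=True):
        ...
        # cfg.MODEL.WEIGHTS is set in the YAML config or via command-line overrides
        self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
        ...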

and it is also loaded in the DefaultPredictor class for inference:

class DefaultPredictor:
    def __init__(self, cfg):
        ...
        # Same config key as in training: cfg.MODEL.WEIGHTS points at the weights to run
        checkpointer.load(cfg.MODEL.WEIGHTS)
        ...

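Whether training actually starts from cfg.MODEL.WEIGHTS depends on the resume flag. Assuming VGT's checkpointer follows fvcore's Checkpointer.resume_or_load (which detectron2's DetectionCheckpointer inherits), the path selection can be sketched with the standard library as below; choose_checkpoint is a hypothetical helper that only mirrors which file gets picked, not the loading itself:

```python
import os
from typing import Tuple


def choose_checkpoint(save_dir: str, weights: str, resume: bool) -> Tuple[str, bool]:
    """Return (path_to_load, resuming) the way fvcore's resume_or_load picks a file.

    Hypothetical helper: it only reproduces the path selection, not the loading.
    """
    # fvcore records the latest saved checkpoint's file name in save_dir/last_checkpoint
    marker = os.path.join(save_dir, "last_checkpoint")
    if resume and os.path.exists(marker):
        with open(marker) as f:
            return os.path.join(save_dir, f.read().strip()), True
    # Fresh run (or resume=False): fall back to the given weights, i.e. cfg.MODEL.WEIGHTS
    return weights, False


if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as out_dir:
        # No checkpoint saved yet, so this prints ('/weights/vgt.pth', False)
        print(choose_checkpoint(out_dir, "/weights/vgt.pth", resume=True))
```

So when fine-tuning from downloaded weights in a fresh output directory (or with resume=False), MODEL.WEIGHTS is the file that gets loaded; in fvcore, that non-resume branch restores only the model weights, not optimizer or scheduler state.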
Please feel free to correct me if I’m mistaken :D

bavo96 • Nov 01 '24 01:11

Thank you @bavo96 !

ritutweets46 • Nov 02 '24 14:11