GGHL
Transfer Learning
Hello, I have a question regarding training a custom dataset.
How can I transfer what was learned for specific classes in the pre-trained weights (e.g., DOTA) to my custom training when my custom classes differ from the pre-trained classes?
Best regards, and thank you.
Hi,

If you want to use the model we trained on the DOTA dataset as the pre-trained model instead of the initial pre-trained weights:

1) Load the model trained on DOTA and, using the keys of the PyTorch state dict, load the weights of every layer except the last prediction layer (that is, do not load the weights of the 15-class classification layer trained on DOTA). Then assign these weights to the model you defined; the classification layer keeps its randomly initialized weights. For example (this runs inside the trainer class, since it uses `self.model` and `self.device`, and `weight_path` is assumed to be defined already):

```python
import os
import torch
weight_path = os.path.join(os.path.split(weight_path)[0], "last.pt")
chkpt = torch.load(weight_path, map_location=self.device)
model_dict = self.model.state_dict()  # weights of the newly defined (custom-class) model
# Load weights by key; also requiring equal shapes skips the class-count-dependent layer
chkpt['model'] = {k: v for k, v in chkpt['model'].items()
                  if k in model_dict and v.shape == model_dict[k].shape}
model_dict.update(chkpt['model'])  # update model weights
self.model.load_state_dict(model_dict)
```

The dictionary comprehension is the line you need to adapt so that the classification weights are not loaded. Checking that the shapes match, as above, is one way to do this, because the shape of that layer depends on the number of classes.

2) Freeze the gradient updates of all layers except the last classification layer, and train only that last layer on your data for a while (it probably does not take long; just train until it converges).

3) Unfreeze the other layers and fine-tune the whole model on your data, with a modest learning rate, until it converges.

The official PyTorch tutorial on model freezing is https://pytorch.org/tutorials/prototype/torchscript_freezing.html?highlight=freeze (note that it covers TorchScript freezing; to stop gradient updates during training, setting `requires_grad = False` on the parameters is the usual approach).

Thank you.
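The three steps above can be sketched end to end on a toy model. This is a minimal illustration under stated assumptions, not GGHL's training code: `ToyDetector`, `load_matching_weights`, `set_backbone_trainable`, `train` and the other helpers are hypothetical names, the 15-class "pretrained" model is randomly initialized rather than loaded from a DOTA checkpoint, the 3 custom classes and the data are synthetic, and the learning rates and epoch counts are placeholders. It loads every tensor whose key and shape match, trains only the head while the other layers are frozen with `requires_grad = False`, then unfreezes everything and fine-tunes with a ten times smaller learning rate.

```python
import torch
import torch.nn as nn


class ToyDetector(nn.Module):
    """Hypothetical stand-in for a detector: shared layers plus a class-dependent head."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU()
        )
        self.head = nn.Linear(32, num_classes)  # shape depends on the class count

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.backbone(x))


def load_matching_weights(model: nn.Module, pretrained_state: dict) -> list:
    """Step 1: copy pretrained tensors whose key and shape match; return skipped keys."""
    model_dict = model.state_dict()
    matched = {k: v for k, v in pretrained_state.items()
               if k in model_dict and v.shape == model_dict[k].shape}
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return [k for k in pretrained_state if k not in matched]


def set_backbone_trainable(model: ToyDetector, trainable: bool) -> None:
    """Freeze (False) or unfreeze (True) every parameter outside the head."""
    for p in model.backbone.parameters():
        p.requires_grad = trainable


def train(model: nn.Module, data: list, lr: float, epochs: int) -> float:
    """Plain SGD loop over only the parameters that currently require gradients."""
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=lr, momentum=0.9)
    loss_fn = nn.CrossEntropyLoss()
    loss = torch.tensor(0.0)
    for _ in range(epochs):
        for x, y in data:
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
    return loss.item()


def backbone_snapshot(model: ToyDetector) -> list:
    """Detached copies of the backbone parameters, for before/after comparisons."""
    return [p.detach().clone() for p in model.backbone.parameters()]


def same_tensors(xs: list, ys: list) -> bool:
    return all(torch.equal(x, y) for x, y in zip(xs, ys))


torch.manual_seed(0)
pretrained = ToyDetector(num_classes=15)  # plays the role of the DOTA model (15 classes)
model = ToyDetector(num_classes=3)        # assumed custom dataset with 3 classes
skipped = load_matching_weights(model, pretrained.state_dict())
backbone_loaded = same_tensors(backbone_snapshot(model), backbone_snapshot(pretrained))

# Synthetic batches, only so the sketch runs end to end
data = [(torch.randn(8, 16), torch.randint(0, 3, (8,))) for _ in range(10)]

# Step 2: freeze everything except the head and train the head only
set_backbone_trainable(model, False)
before = backbone_snapshot(model)
train(model, data, lr=1e-2, epochs=5)
backbone_frozen_ok = same_tensors(backbone_snapshot(model), before)

# Step 3: unfreeze and fine-tune the whole model with a smaller learning rate
set_backbone_trainable(model, True)
before = backbone_snapshot(model)
train(model, data, lr=1e-3, epochs=5)
backbone_updated = not same_tensors(backbone_snapshot(model), before)

print("skipped:", skipped)  # → skipped: ['head.weight', 'head.bias']
```

A fresh optimizer is created for each stage so that the parameters unfrozen in step 3 are actually optimized. In a real detector, `requires_grad = False` stops gradient updates but does not stop BatchNorm layers from updating their running statistics in training mode; putting the frozen BatchNorm modules in `eval()` mode is a common extra step.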