
After changing num_input_channels in a ResNet during training, the model cannot be loaded again from the checkpoint.

eric-tc opened this issue · 3 comments

🐛 Describe the bug

Hi, I came across this issue.

If I train a ResNet18 model with a changed number of input channels:

train.py

model = models.get(model_name=Models.RESNET18, pretrained_weights="imagenet", num_classes=num_classes, num_input_channels=1)

then I cannot load the model again:

test.py
best_model = models.get(
            Models.RESNET18,
            num_classes=3,
            num_input_channels=1,
            checkpoint_path=checkpoint
            )

The error is the following:

RuntimeError: Error(s) in loading state_dict for ResNet18: size mismatch for conv1.weight: copying a param with shape torch.Size([64, 1, 7, 7]) from checkpoint, the shape in current model is torch.Size([64, 3, 7, 7]).

This is because in training/models/model_factory.py the checkpoint is loaded before the input channels are replaced.
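You can confirm this by inspecting the checkpoint directly. A quick sketch (the path is illustrative, and the "net" key is an assumption based on the default SuperGradients checkpoint layout):

import torch

ckpt = torch.load("checkpoints/my_experiment/ckpt_best.pth", map_location="cpu")
state_dict = ckpt.get("net", ckpt)  # fall back to the raw dict if there is no "net" key
print(state_dict["conv1.weight"].shape)  # torch.Size([64, 1, 7, 7]) for the 1-channel model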

Changing the order so that the following lines run before the checkpoint is read,

# Replace the input channels first...
if num_input_channels is not None and num_input_channels != net.get_input_channels():
    net.replace_input_channels(in_channels=num_input_channels)

# ...and only then read the checkpoint.
if checkpoint_path:
    ckpt_entries = read_ckpt_state_dict(ckpt_path=checkpoint_path).keys()
    load_processing = "processing_params" in ckpt_entries

then I can load the model correctly.
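Until a fix is released, a possible user-side workaround is to build the model without checkpoint_path, replace the input channels, and then load the weights manually. A sketch (the "net" key is an assumption based on the default SuperGradients checkpoint format, and it assumes the checkpoint keys match the model exactly):

import torch
from super_gradients.common.object_names import Models
from super_gradients.training import models

best_model = models.get(Models.RESNET18, num_classes=3)
best_model.replace_input_channels(in_channels=1)  # conv1 now expects 1 channel

ckpt = torch.load(checkpoint, map_location="cpu")
best_model.load_state_dict(ckpt.get("net", ckpt))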

Versions

3.4.1

eric-tc · Nov 23, 2023

@eric-tc thank you for opening this issue. We will introduce a fix in the next release. If you don't want to wait until then, contributions are always welcome.

shaydeci · Nov 26, 2023


How do I resume from interrupted training in YOLO-NAS? I think this is also a bug: I modified the YAML, but training still starts from the beginning.

lsm140 · Dec 7, 2023

Hi, you can just set the resume param.

In the .yaml config:

training_hyperparams:
    resume: true

Or, in Python code:

training_params = {...} # Whatever you have been using
training_params['resume'] = True
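
For context, a minimal sketch of how the resume flag fits into a full training call (the experiment name, loaders, and checkpoint layout are placeholders). With resume set, the Trainer picks up the latest checkpoint saved under the same experiment_name and ckpt_root_dir:

from super_gradients import Trainer

# Must match the interrupted run so the Trainer can find its checkpoints,
# e.g. checkpoints/yolo_nas_run/ckpt_latest.pth (names are illustrative).
trainer = Trainer(experiment_name="yolo_nas_run", ckpt_root_dir="checkpoints")

training_params = {...}  # the same params used for the interrupted run
training_params["resume"] = True

trainer.train(
    model=model,
    training_params=training_params,
    train_loader=train_loader,
    valid_loader=valid_loader,
)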

Louis-Dupont · Dec 25, 2023