DenoiseCompression

Test: load state_dict when inference

Open: hamedsteiner opened this issue 2 years ago • 1 comment

Hi, when loading the checkpoint at line 118 of util.py (https://github.com/felixcheng97/DenoiseCompression/blob/main/CompressAI/codes/utils/util.py#L118), shouldn't the call be changed to m = architectures[model].from_state_dict(checkpoint, opt)?

Without this change, loading raises an error.

hamedsteiner, Sep 19 '22 21:09

Thanks for your question.

I have double-checked, and our test code should handle checkpoint/state_dict loading correctly.

In our code, a checkpoint is a dictionary saved during training. It contains entries such as 'epoch', 'iter', 'state_dict', 'loss', 'optimizer', 'aux_optimizer', 'lr_scheduler' and 'aux_lr_scheduler', as shown here: https://github.com/felixcheng97/DenoiseCompression/blob/009c7539638ab85b8d21d174c8f87934bc2264cc/CompressAI/codes/train.py#L361-L373

A state_dict, by contrast, holds only the network's parameter weights; it is simply model.state_dict(), as shown here: https://github.com/felixcheng97/DenoiseCompression/blob/009c7539638ab85b8d21d174c8f87934bc2264cc/CompressAI/codes/train.py#L364
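To make the difference concrete, here is a small pure-Python sketch (no PyTorch needed) that compares the keys of a loaded object against a model's expected parameter names, in the spirit of the "missing"/"unexpected" key report that torch.nn.Module.load_state_dict(strict=True) produces. The parameter names and placeholder values are hypothetical; only the checkpoint key names come from train.py.

```python
from typing import Dict, Iterable, List, Tuple

# Hypothetical parameter names of a tiny model; a real model.state_dict()
# maps names like these to tensors.
EXPECTED_PARAMS = ["encoder.weight", "encoder.bias", "decoder.weight"]


def compare_keys(expected: Iterable[str], loaded: Dict[str, object]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) keys, similar to a strict load_state_dict check."""
    expected_set = set(expected)
    missing = sorted(expected_set - set(loaded))
    unexpected = sorted(set(loaded) - expected_set)
    return missing, unexpected


# A bare state_dict: parameter names map directly to (placeholder) weights.
state_dict = {name: [0.0] for name in EXPECTED_PARAMS}

# A training checkpoint: the same state_dict nested next to training state.
# All values other than 'state_dict' are placeholders.
training_checkpoint = {
    "epoch": 10,
    "iter": 50000,
    "state_dict": state_dict,
    "loss": 0.1,
    "optimizer": {},
    "aux_optimizer": {},
    "lr_scheduler": {},
    "aux_lr_scheduler": {},
}

if __name__ == "__main__":
    # Bare state_dict: nothing missing, nothing unexpected.
    print(compare_keys(EXPECTED_PARAMS, state_dict))
    # Whole training checkpoint: every parameter is missing and
    # every training entry is unexpected.
    print(compare_keys(EXPECTED_PARAMS, training_checkpoint))
    # Unwrapping 'state_dict' first gives a clean match again.
    print(compare_keys(EXPECTED_PARAMS, training_checkpoint["state_dict"]))
```

With a real model, strict loading of the wrong object would fail with a RuntimeError listing such keys, which is why the two kinds of saved objects need different handling.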

Accordingly, there are two separate if branches, one for a training_checkpoint and one for an (updated_)state_dict: https://github.com/felixcheng97/DenoiseCompression/blob/009c7539638ab85b8d21d174c8f87934bc2264cc/CompressAI/codes/utils/util.py#L117-L120
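The two branches can be sketched roughly as below. This is a hypothetical pure-Python illustration, not the code at util.py#L117-L120: the helper name extract_weights and the is_training_checkpoint flag are assumptions standing in for whatever the real code derives from the config. It also shows one plausible way the reported error can arise: a bare state_dict routed into the training-checkpoint branch has no 'state_dict' entry to unwrap.

```python
from typing import Any, Dict, Mapping


def extract_weights(loaded: Mapping[str, Any], is_training_checkpoint: bool) -> Dict[str, Any]:
    """Return the model weights from either kind of saved object (hypothetical helper)."""
    if is_training_checkpoint:
        # Branch 1: a training checkpoint nests the weights under 'state_dict'
        # next to 'epoch', 'iter', 'optimizer', and so on.
        if "state_dict" not in loaded:
            raise KeyError(
                "no 'state_dict' entry: this looks like a bare state_dict, "
                "not a training checkpoint"
            )
        return dict(loaded["state_dict"])
    # Branch 2: an (updated) state_dict already maps parameter names to weights.
    return dict(loaded)


if __name__ == "__main__":
    weights = {"encoder.weight": [0.5], "encoder.bias": [0.0]}  # placeholder values
    checkpoint = {"epoch": 1, "iter": 100, "state_dict": weights}

    print(extract_weights(checkpoint, is_training_checkpoint=True) == weights)
    print(extract_weights(weights, is_training_checkpoint=False) == weights)
    try:
        extract_weights(weights, is_training_checkpoint=True)  # wrong branch
    except KeyError as err:
        print("KeyError:", err)
```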

The released pretrained models are saved as updated_state_dicts, so the 'checkpoint' key is commented out in our config yml files (the 'update' key can be either true or false). To test a training_checkpoint instead, uncomment the 'checkpoint' key and set the 'update' key to true. https://github.com/felixcheng97/DenoiseCompression/blob/009c7539638ab85b8d21d174c8f87934bc2264cc/CompressAI/codes/conf/test/multiscale-decomp_sidd_mse_q1.yml#L27-L30
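Read as a rule, this config guidance says: if 'checkpoint' is present, load a training checkpoint and require 'update' to be true; otherwise load an updated state_dict and accept either value of 'update'. Below is a minimal sketch of that rule, assuming the parsed yml is available as a flat Python dict (the real file may nest these keys under other sections) and using a hypothetical helper name, plan_loading. Why 'update' matters is our assumption, not something stated in the thread: in CompressAI, a model's update() method fills in the entropy-coder CDF tables, which an updated state_dict already carries but a raw training checkpoint may not.

```python
from typing import Any, Mapping, Tuple


def plan_loading(opt: Mapping[str, Any]) -> Tuple[str, bool]:
    """Decide which loading branch a test config selects (hypothetical helper).

    Returns (kind, run_update), where kind is 'training_checkpoint'
    or 'updated_state_dict'.
    """
    run_update = bool(opt.get("update", False))
    if opt.get("checkpoint") is not None:
        # 'checkpoint' uncommented: a raw training checkpoint, which is
        # assumed to need the update step before real entropy coding.
        if not run_update:
            raise ValueError("testing a training checkpoint requires update: true")
        return "training_checkpoint", True
    # 'checkpoint' commented out (key absent): a released, already updated
    # state_dict, for which update may be either true or false.
    return "updated_state_dict", run_update


if __name__ == "__main__":
    print(plan_loading({"update": False}))
    print(plan_loading({"checkpoint": "ckpt.pth", "update": True}))  # placeholder path
```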

If you still have problems loading the training_checkpoint or the updated_state_dict, please post more details here, in particular the config (yml) settings that cause the error.

Hope this helps.

felixcheng97, Sep 20 '22 09:09