
`restore` options to resume `distill.py`

Open youjin-c opened this issue 2 years ago • 7 comments

Hello,

I got this error when exporting to ONNX.

Traceback (most recent call last):
  File "/home/ubuntu/CAT/onnx_export.py", line 13, in <module>
    exporter = Exporter()
  File "/home/ubuntu/CAT/onnx_exporter.py", line 59, in __init__
    model.netG_student.load_state_dict(
  File "/home/ubuntu/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for InceptionGenerator:
	size mismatch for down_sampling.7.weight: copying a param with shape torch.Size([234, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([230, 16, 3, 3]).
	size mismatch for down_sampling.7.bias: copying a param with shape torch.Size([234]) from checkpoint, the shape in current model is torch.Size([230]).
...
and many more size mismatches of the same kind.

This is strange, since it worked before and actually exported models. Let me share my commands for distilling and exporting.

!python distill.py --dataroot database/face2smile \
  --dataset_mode unaligned \
  --distiller inception \
  --gan_mode lsgan \
  --log_dir logs/cycle_gan/face2smile/inception/student/2p6B \
  --restore_teacher_G_path logs/cycle_gan/face2smile/inception/teacher/checkpoints/best_B_net_G_A.pth \
  --restore_pretrained_G_path logs/cycle_gan/face2smile/inception/teacher/checkpoints/best_B_net_G_A.pth \
  --restore_D_path logs/cycle_gan/face2smile/inception/teacher/checkpoints/best_B_net_D_A.pth \
  --real_stat_path real_stat/face2smile_B.npz \
  --nepochs 500 --nepochs_decay 500 \
  --teacher_netG inception_9blocks --student_netG inception_9blocks \
  --pretrained_ngf 64 --teacher_ngf 64 --student_ngf 20 \
  --ndf 64 \
  --num_threads 80 \
  --eval_batch_size 4 \
  --batch_size 80 \
  --gpu_ids 0,1,2,3 \
  --norm_affine \
  --norm_affine_D \
  --channels_reduction_factor 6 \
  --kernel_sizes 1 3 5 \
  --lambda_distill 1.0 \
  --lambda_recon 5 \
  --prune_cin_lb 16 \
  --target_flops 2.6e9 \
  --distill_G_loss_type ka \
  --save_epoch_freq 1 \
  --save_latest_freq 500 \
  --norm_student batch \
  --padding_type_student zero \
  --norm_affine_student \
  --norm_track_running_stats_student
!python3 onnx_export.py --dataroot database/face2smile \
  --log_dir onnx_files/cycle_gan/face2smile/inception/student/2p6B \
  --restore_teacher_G_path logs/cycle_gan/face2smile/inception/teacher/checkpoints/best_A_net_G_A.pth \
  --pretrained_student_G_path logs/cycle_gan/face2smile/inception/student/2p6B/checkpoints/best_net_G.pth \
  --real_stat_path real_stat/face2smile_B.npz \
  --dataset_mode unaligned \
  --pretrained_ngf 64 --teacher_ngf 64 --student_ngf 20 \
  --gpu_ids 0 \
  --norm_affine \
  --channels_reduction_factor 6 \
  --kernel_sizes 1 3 5 \
  --prune_cin_lb 16 \
  --target_flops 2.6e9 \
  --ndf 64 \
  --batch_size 8 \
  --eval_batch_size 2 \
  --num_threads 8 \
  --norm_affine_D \
  --teacher_netG inception_9blocks --student_netG inception_9blocks \
  --distiller inception \
  --gan_mode lsgan \
  --norm_student batch \
  --padding_type_student zero \
  --norm_affine_student \
  --norm_track_running_stats_student
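
In case it helps with debugging, here is a minimal sketch (plain PyTorch, not part of CAT; the model handle stands for whatever InceptionGenerator instance onnx_export.py builds) that lists every tensor whose shape differs between a checkpoint and a freshly built model:

import torch

def diff_shapes(ckpt_path, model):
    """Print every tensor whose shape differs between a saved checkpoint
    and the state_dict of a freshly built model."""
    ckpt = torch.load(ckpt_path, map_location='cpu')
    current = model.state_dict()
    for name, tensor in ckpt.items():
        if name not in current:
            print(f'{name}: not present in the current model')
        elif current[name].shape != tensor.shape:
            print(f'{name}: checkpoint {tuple(tensor.shape)} vs model {tuple(current[name].shape)}')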

youjin-c avatar Aug 30 '22 20:08 youjin-c

I read issue #11. I checked that my branch is tutorial, and my error happened in onnx_export.py, not with distill.py.

youjin-c avatar Aug 30 '22 21:08 youjin-c

So I fixed the issue.

  1. For some reason, I got an error when onnx_export.py had multiple gpu_ids: it could not find some element in a dictionary. Setting --gpu_ids 0 fixed it (see the sketch after this list for a possible explanation).
  2. I accidentally resumed distill.py with the wrong file (I referred to one B file while the others were A files). When I matched all .pth files to the A models, I could export an onnx file without the error above.
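
For point 1, my guess (unverified; whether CAT's multi-GPU checkpoints actually carry the prefix is an assumption on my part) is the usual DataParallel naming issue: checkpoints saved from an nn.DataParallel model prefix every key with 'module.', which a plain single-GPU model does not expect. A generic workaround sketch:

from collections import OrderedDict
import torch

def strip_module_prefix(ckpt_path):
    """Load a checkpoint saved from an nn.DataParallel model and drop the
    'module.' key prefix so it loads into a plain single-GPU model."""
    state = torch.load(ckpt_path, map_location='cpu')
    return OrderedDict(
        (k[len('module.'):] if k.startswith('module.') else k, v)
        for k, v in state.items())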

However, I am still not sure what the right options are to resume distilling. I referred to some earlier issues like #11, but the networks did not load properly. Judging by the fid values and other metrics, distillation just started from scratch instead of loading the referenced model.

So, the question is: what are the proper options to resume distill.py?

youjin-c avatar Aug 31 '22 07:08 youjin-c

In particular, the three options restore_teacher_G_path, restore_pretrained_G_path, and restore_student_G_path are confusing. Do I need to pass the best student G (or the student G that I want to restore) to both restore_pretrained_G_path and restore_student_G_path?

youjin-c avatar Aug 31 '22 16:08 youjin-c

Hi @youjinChung, thanks for your interest in our work. You could try to check this function, which shows how each of the three restore options is loaded. I am not sure how you specified restore_student_G_path, which is the student checkpoint you need to use to resume distillation, or how you specified the prune_continue flag.
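
As a generic way to see what is and is not being restored (plain PyTorch, nothing CAT-specific), you could also load non-strictly and inspect the reported keys:

import torch

def report_load(model, ckpt_path):
    """Non-strict load: prints the checkpoint keys that are missing from or
    unexpected by the model (shape mismatches still raise as usual)."""
    state = torch.load(ckpt_path, map_location='cpu')
    result = model.load_state_dict(state, strict=False)
    print('missing keys:', result.missing_keys)
    print('unexpected keys:', result.unexpected_keys)
    return result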

Hope the above is clear now.

deJQK avatar Aug 31 '22 22:08 deJQK

Hi @youjinChung, thanks again for your interest. Just checking whether you could run the code successfully.

alanspike avatar Sep 06 '22 16:09 alanspike

Hello, I am training the teacher model now; let me try distilling again when it is done and come back to this thread.

youjin-c avatar Sep 06 '22 19:09 youjin-c

Hello,

So I checked with the restore_student_G_path and prune_continue options, but I don't see the distiller resume the model from the latest fid score. I expected resuming to actually carry over values like G_gan, G_distill, G_recon, ... and the fid scores as well, but everything still seems initialized. Please let me know if my understanding is incorrect.

Below are my resume command and the logs from before and after resuming.

!python distill.py --dataroot database/face2smile \
  --dataset_mode unaligned \
  --distiller inception \
  --gan_mode lsgan \
  --log_dir logs/cycle_gan/face2smile/student_512_resume \
  --restore_teacher_G_path logs/cycle_gan/face2smile/teacher_512/checkpoints/170_net_G_A.pth \
  --real_stat_path real_stat_512/face2smile_B.npz \
  --teacher_netG inception_9blocks --student_netG inception_9blocks \
  --pretrained_ngf 64 --teacher_ngf 64 --student_ngf 20 \
  --ndf 64 \
  --num_threads 32 \
  --eval_batch_size 4 \
  --batch_size 40 \
  --gpu_ids 0,1,2,3 \
  --norm_affine \
  --norm_affine_D \
  --channels_reduction_factor 6 \
  --kernel_sizes 1 3 5 \
  --lambda_distill 1.0 \
  --lambda_recon 5 \
  --prune_cin_lb 16 \
  --target_flops 2.6e9 \
  --distill_G_loss_type ka \
  --save_epoch_freq 1 \
  --save_latest_freq 500 \
  --norm_student batch \
  --padding_type_student zero \
  --norm_affine_student \
  --norm_track_running_stats_student \
  --preprocess scale_width --load_size 512 \
  --nepochs 0 --nepochs_decay 325 \
  --epoch_base 675 --iter_base 85600 \
  --restore_student_G_path logs/cycle_gan/face2smile/student_512/checkpoints/674_net_G.pth \
  --restore_A_path logs/cycle_gan/face2smile/student_512/checkpoints/674_net_A \
  --restore_O_path logs/cycle_gan/face2smile/student_512/checkpoints/674_optim \
  --restore_D_path logs/cycle_gan/face2smile/student_512/checkpoints/674_net_D.pth \
  --prune_continue
(epoch: 674, iters: 85500, time: 1.136) G_gan: 1.037 G_distill: -15.940 G_recon: 0.304 D_fake: 0.008 D_real: 0.005 
###(Evaluate epoch: 674, iters: 85500, time: 63.606) fid: 75.132 fid-mean: 75.769 fid-best: 70.321 
Saving the latest model (epoch 674, total_steps 85500)
End of epoch 674 / 1000 	 Time Taken: 228.54 sec
###(Evaluate epoch: 674, iters: 85599, time: 63.101) fid: 71.786 fid-mean: 74.345 fid-best: 70.321 
Saving the model at the end of epoch 674, iters 85599
learning rate = 0.0001301

This is before resuming.

All networks loaded.
netG student FLOPs: 2580709376; down sampling: 873725952; features: 61767680; up sampling: 1645215744.
(epoch: 675, iters: 85600, time: 13.963) G_gan: 0.795 G_distill: -12.844 G_recon: 2.335 D_fake: 0.257 D_real: 0.004 
###(Evaluate epoch: 675, iters: 85600, time: 54.235) fid: 435.564 fid-mean: 435.564 fid-best: 435.564 
Saving the latest model (epoch 675, total_steps 85600)
(epoch: 675, iters: 85700, time: 1.117) G_gan: 1.066 G_distill: -15.767 G_recon: 1.448 D_fake: 0.016 D_real: 0.012 
End of epoch 675 / 999 	 Time Taken: 226.41 sec
###(Evaluate epoch: 675, iters: 85727, time: 61.527) fid: 246.713 fid-mean: 341.138 fid-best: 246.713 
Saving the model at the end of epoch 675, iters 85727
learning rate = 0.0001988

and this is after resuming. Apparently the network is not loaded properly: the fid jumps from 71.786 just before saving to 435.564 right after resuming, and when I check the eval images, they are initialized to gray.
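
To confirm, I would compare the checkpoint tensors against the live network right after "All networks loaded." is printed. A minimal sketch (plain PyTorch; how to reach the student network from the distiller object is my assumption):

import torch

def verify_restored(ckpt_path, model):
    """Count tensors whose values differ between the checkpoint and the live
    model; a large count means the restore did not take effect."""
    ckpt = torch.load(ckpt_path, map_location='cpu')
    current = model.state_dict()
    differing = [k for k, v in ckpt.items()
                 if k in current and not torch.equal(current[k].cpu(), v)]
    print(f'{len(differing)} of {len(ckpt)} tensors differ from the checkpoint')
    return differing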

youjin-c avatar Sep 16 '22 19:09 youjin-c