CAT
`restore` options to resume `distill.py`
Hello,
I got this error when exporting to ONNX:
```
Traceback (most recent call last):
  File "/home/ubuntu/CAT/onnx_export.py", line 13, in <module>
    exporter = Exporter()
  File "/home/ubuntu/CAT/onnx_exporter.py", line 59, in __init__
    model.netG_student.load_state_dict(
  File "/home/ubuntu/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for InceptionGenerator:
    size mismatch for down_sampling.7.weight: copying a param with shape torch.Size([234, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([230, 16, 3, 3]).
    size mismatch for down_sampling.7.bias: copying a param with shape torch.Size([234]) from checkpoint, the shape in current model is torch.Size([230]).
...
```
and many more mismatches.
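The mismatches shown above are in the channel dimension (234 in the checkpoint vs 230 in the freshly built model), not in kernel sizes, which suggests the exporter built the student with a different channel configuration than the one that was saved. A minimal sketch for listing every such difference at once instead of reading the whole traceback; the helper name is hypothetical, and it works on plain `{name: shape}` dicts so it does not depend on CAT's code.

```python
def shape_mismatches(checkpoint_shapes, model_shapes):
    """Compare parameter shapes of a checkpoint against a freshly built model.

    Hypothetical helper. Both arguments map parameter names to shape tuples.
    Returns (mismatched, missing_in_checkpoint, unexpected_in_checkpoint).
    """
    # Parameters present in both, but with different shapes.
    mismatched = {
        name: (checkpoint_shapes[name], model_shapes[name])
        for name in checkpoint_shapes.keys() & model_shapes.keys()
        if tuple(checkpoint_shapes[name]) != tuple(model_shapes[name])
    }
    # Parameters the model expects but the checkpoint lacks, and vice versa.
    missing = sorted(model_shapes.keys() - checkpoint_shapes.keys())
    unexpected = sorted(checkpoint_shapes.keys() - model_shapes.keys())
    return mismatched, missing, unexpected
```

With PyTorch, the two dicts could be built as `{k: tuple(v.shape) for k, v in torch.load(path, map_location="cpu").items()}` and `{k: tuple(v.shape) for k, v in model.netG_student.state_dict().items()}`, assuming the `.pth` file stores a plain state dict.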
This is strange, because the same setup worked before and actually exported models. Here are my commands for distilling and exporting:
```
!python distill.py --dataroot database/face2smile \
--dataset_mode unaligned \
--distiller inception \
--gan_mode lsgan \
--log_dir logs/cycle_gan/face2smile/inception/student/2p6B \
--restore_teacher_G_path logs/cycle_gan/face2smile/inception/teacher/checkpoints/best_B_net_G_A.pth \
--restore_pretrained_G_path logs/cycle_gan/face2smile/inception/teacher/checkpoints/best_B_net_G_A.pth \
--restore_D_path logs/cycle_gan/face2smile/inception/teacher/checkpoints/best_B_net_D_A.pth \
--real_stat_path real_stat/face2smile_B.npz \
--nepochs 500 --nepochs_decay 500 \
--teacher_netG inception_9blocks --student_netG inception_9blocks \
--pretrained_ngf 64 --teacher_ngf 64 --student_ngf 20 \
--ndf 64 \
--num_threads 80 \
--eval_batch_size 4 \
--batch_size 80 \
--gpu_ids 0,1,2,3 \
--norm_affine \
--norm_affine_D \
--channels_reduction_factor 6 \
--kernel_sizes 1 3 5 \
--lambda_distill 1.0 \
--lambda_recon 5 \
--prune_cin_lb 16 \
--target_flops 2.6e9 \
--distill_G_loss_type ka \
--save_epoch_freq 1 \
--save_latest_freq 500 \
--norm_student batch \
--padding_type_student zero \
--norm_affine_student \
--norm_track_running_stats_student
```

```
!python3 onnx_export.py --dataroot database/face2smile \
--log_dir onnx_files/cycle_gan/face2smile/inception/student/2p6B \
--restore_teacher_G_path logs/cycle_gan/face2smile/inception/teacher/checkpoints/best_A_net_G_A.pth \
--pretrained_student_G_path logs/cycle_gan/face2smile/inception/student/2p6B/checkpoints/best_net_G.pth \
--real_stat_path real_stat/face2smile_B.npz \
--dataset_mode unaligned \
--pretrained_ngf 64 --teacher_ngf 64 --student_ngf 20 \
--gpu_ids 0 \
--norm_affine \
--channels_reduction_factor 6 \
--kernel_sizes 1 3 5 \
--prune_cin_lb 16 \
--target_flops 2.6e9 \
--ndf 64 \
--batch_size 8 \
--eval_batch_size 2 \
--num_threads 8 \
--norm_affine_D \
--teacher_netG inception_9blocks --student_netG inception_9blocks \
--distiller inception \
--gan_mode lsgan \
--norm_student batch \
--padding_type_student zero \
--norm_affine_student \
--norm_track_running_stats_student
```
I read issue #11. I checked that my branch is `tutorial`, and my error happens in `onnx_export.py`, not in `distill.py`.
In the end, I fixed the issue:
- For some reason, I got an error when `onnx_export.py` was given multiple `gpu_ids`. Setting `--gpu_ids 0` fixed the error saying that some element could not be found in the dictionary.
- I had accidentally resumed `distill.py` with the wrong file (I referred to one B file while the others were A files). When I matched all `.pth` files to the A models, I could export to an ONNX file without the error above.
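A toy illustration of why swapping a single checkpoint can change layer widths. The maintainer's reply in this thread says `restore_teacher_G_path` is also used to determine the student's architecture (pruning); if so, two teacher checkpoints with identical architectures can yield different pruned widths, and a student saved under one cannot be loaded into a model built from the other. The pruning rule below (keep channels whose normalization-scale magnitude clears a global threshold, with a per-layer lower bound loosely analogous to `--prune_cin_lb`) and the scale values are made up for illustration; this is an assumption, not CAT's pruning code.

```python
def kept_channels(scales, threshold, lower_bound):
    """Count channels kept per layer under a toy magnitude-pruning rule.

    Hypothetical rule: keep channels whose |scale| >= threshold, but never
    fewer than `lower_bound` (capped at the layer's width).
    """
    counts = []
    for layer in scales:
        kept = sum(1 for s in layer if abs(s) >= threshold)
        counts.append(max(kept, min(lower_bound, len(layer))))
    return counts


# Made-up per-channel scales from two checkpoints of the same architecture.
teacher_ckpt_1 = [[0.9, 0.4, 0.05, 0.7], [0.3, 0.02, 0.6, 0.8]]
teacher_ckpt_2 = [[0.9, 0.4, 0.15, 0.7], [0.3, 0.02, 0.6, 0.08]]

print(kept_channels(teacher_ckpt_1, 0.1, 2))  # [3, 3]
print(kept_channels(teacher_ckpt_2, 0.1, 2))  # [4, 2]
```

Same architecture, same threshold, different widths: the same kind of discrepancy as 234 vs 230 channels in the traceback.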
However, I am still not sure which options are right for resuming distillation. I followed earlier issues such as #11, but the networks were not loaded properly: judging by the FID values and other metrics, distillation simply started from scratch instead of loading the referenced model.
So, may I ask what the proper options are for resuming `distill.py`?
In particular, these three options are confusing: `restore_teacher_G_path`, `restore_pretrained_G_path`, and `restore_student_G_path`.
Do I need to pass the best student G (or whichever student G I want to restore) to both `restore_pretrained_G_path` and `restore_student_G_path`?
Hi @youjinChung, thanks for your interest in our work. You could try to check this function. I am not sure how you specified `restore_student_G_path`, which is the student checkpoint used to resume distillation, or how you set the `prune_continue` flag. As for the three options:
- `restore_teacher_G_path` is the checkpoint loaded into the teacher for distillation; it is also used to determine the student's architecture when building the student model (pruning).
- `restore_pretrained_G_path` is the checkpoint loaded into the pretrained network, which determines the initialization of the student trained with knowledge distillation. In the method of the paper we use the pretrained teacher supernet for this, but you can freely choose any other initial pretrained model.
- `restore_student_G_path` is the checkpoint loaded into the already pruned student for resuming distillation.
Hope this makes it clear.
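A minimal sketch of assembling the resume flags so that every student-side path points at the same saved epoch and the teacher checkpoint is passed explicitly (it should match the original run, since it also shapes the pruned student). The helper name is hypothetical; the flag names are the ones used in the commands in this thread, and the file names follow the `674_net_G.pth`, `674_net_A`, `674_optim`, `674_net_D.pth` pattern seen there, which is an assumption about what `distill.py` saves.

```python
from pathlib import Path


def build_resume_args(ckpt_dir, epoch, iter_base, teacher_G_path):
    """Return restore-related flags for resuming distillation (hypothetical).

    File-name pattern and flag set mirror the commands in this thread; they
    are assumptions, not a documented interface of distill.py.
    """
    ckpt = Path(ckpt_dir)
    return [
        # Same teacher checkpoint as the original run: it also determines
        # the pruned student architecture.
        "--restore_teacher_G_path", str(teacher_G_path),
        # All student-side state comes from the SAME saved epoch.
        "--restore_student_G_path", str(ckpt / f"{epoch}_net_G.pth"),
        "--restore_A_path", str(ckpt / f"{epoch}_net_A"),
        "--restore_O_path", str(ckpt / f"{epoch}_optim"),
        "--restore_D_path", str(ckpt / f"{epoch}_net_D.pth"),
        # Continue counting after the saved epoch / iteration.
        "--epoch_base", str(epoch + 1),
        "--iter_base", str(iter_base),
        "--prune_continue",
    ]
```

`shlex.join(build_resume_args(...))` produces a string that can be appended to the rest of the `distill.py` command, with paths quoted if needed.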
Hi @youjinChung, thanks again for your interest. Just checking whether you were able to run the code successfully.
Hello, I am training the teacher model now. I will try distillation again and come back to this thread.
Hello,
I tried the `restore_student_G_path` and `prune_continue` options, but the distiller does not resume from the latest FID score.
I expected resuming to carry over values such as G_gan, G_distill, G_recon, ... and the FID scores as well, but everything still looks re-initialized. Please let me know if my understanding is incorrect.
Below are my resume command and the logs before and after resuming.
```
!python distill.py --dataroot database/face2smile \
--dataset_mode unaligned \
--distiller inception \
--gan_mode lsgan \
--log_dir logs/cycle_gan/face2smile/student_512_resume \
--restore_teacher_G_path logs/cycle_gan/face2smile/teacher_512/checkpoints/170_net_G_A.pth \
--real_stat_path real_stat_512/face2smile_B.npz \
--teacher_netG inception_9blocks --student_netG inception_9blocks \
--pretrained_ngf 64 --teacher_ngf 64 --student_ngf 20 \
--ndf 64 \
--num_threads 32 \
--eval_batch_size 4 \
--batch_size 40 \
--gpu_ids 0,1,2,3 \
--norm_affine \
--norm_affine_D \
--channels_reduction_factor 6 \
--kernel_sizes 1 3 5 \
--lambda_distill 1.0 \
--lambda_recon 5 \
--prune_cin_lb 16 \
--target_flops 2.6e9 \
--distill_G_loss_type ka \
--save_epoch_freq 1 \
--save_latest_freq 500 \
--norm_student batch \
--padding_type_student zero \
--norm_affine_student \
--norm_track_running_stats_student \
--preprocess scale_width --load_size 512 \
--nepochs 0 --nepochs_decay 325 \
--epoch_base 675 --iter_base 85600 \
--restore_student_G_path logs/cycle_gan/face2smile/student_512/checkpoints/674_net_G.pth \
--restore_A_path logs/cycle_gan/face2smile/student_512/checkpoints/674_net_A \
--restore_O_path logs/cycle_gan/face2smile/student_512/checkpoints/674_optim \
--restore_D_path logs/cycle_gan/face2smile/student_512/checkpoints/674_net_D.pth \
--prune_continue
```
```
(epoch: 674, iters: 85500, time: 1.136) G_gan: 1.037 G_distill: -15.940 G_recon: 0.304 D_fake: 0.008 D_real: 0.005
###(Evaluate epoch: 674, iters: 85500, time: 63.606) fid: 75.132 fid-mean: 75.769 fid-best: 70.321
Saving the latest model (epoch 674, total_steps 85500)
End of epoch 674 / 1000 Time Taken: 228.54 sec
###(Evaluate epoch: 674, iters: 85599, time: 63.101) fid: 71.786 fid-mean: 74.345 fid-best: 70.321
Saving the model at the end of epoch 674, iters 85599
learning rate = 0.0001301
```
This is before resuming.
```
All networks loaded.
netG student FLOPs: 2580709376; down sampling: 873725952; features: 61767680; up sampling: 1645215744.
(epoch: 675, iters: 85600, time: 13.963) G_gan: 0.795 G_distill: -12.844 G_recon: 2.335 D_fake: 0.257 D_real: 0.004
###(Evaluate epoch: 675, iters: 85600, time: 54.235) fid: 435.564 fid-mean: 435.564 fid-best: 435.564
Saving the latest model (epoch 675, total_steps 85600)
(epoch: 675, iters: 85700, time: 1.117) G_gan: 1.066 G_distill: -15.767 G_recon: 1.448 D_fake: 0.016 D_real: 0.012
End of epoch 675 / 999 Time Taken: 226.41 sec
###(Evaluate epoch: 675, iters: 85727, time: 61.527) fid: 246.713 fid-mean: 341.138 fid-best: 246.713
Saving the model at the end of epoch 675, iters 85727
learning rate = 0.0001988
```
And this is after resuming. Apparently the network is not loaded properly: when I check the eval images, they are plain gray, as if produced by a freshly initialized network.
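The logs above already show the symptom: FID goes from 71.786 at the end of epoch 674 to 435.564 at the first evaluation after resuming. A small sketch for pulling FID values out of such logs and comparing the last value before resuming with the first one after; the helper names are hypothetical, and the regex only matches the `###(Evaluate ...)` line format shown in this thread.

```python
import re

# Matches evaluation lines such as:
# ###(Evaluate epoch: 674, iters: 85599, time: 63.101) fid: 71.786 fid-mean: ...
EVAL_RE = re.compile(
    r"###\(Evaluate epoch: (\d+), iters: (\d+), time: [\d.]+\) "
    r"fid: ([\d.]+) fid-mean: ([\d.]+) fid-best: ([\d.]+)"
)


def parse_eval_lines(lines):
    """Return (epoch, iters, fid) tuples for every evaluation line found."""
    results = []
    for line in lines:
        match = EVAL_RE.search(line)
        if match:
            results.append(
                (int(match.group(1)), int(match.group(2)), float(match.group(3)))
            )
    return results


def fid_jump(before, after):
    """Ratio of the first FID after resuming to the last FID before it."""
    return after[0][2] / before[-1][2]
```

For the two excerpts above the ratio is about 6. A ratio far above 1 is consistent with the generator weights not being restored, though it does not tell which flag is at fault.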