
Does t_net also need to be trained?

Open · TouchSkyWf opened this issue 2 years ago · 1 comment

```python
d_net.train()
d_net.module.s_net.train()
d_net.module.t_net.train()
```

Hi,

I noticed that the teacher model is also set to training mode. In the usual distillation process, shouldn't the teacher be in inference mode? Having both networks in training mode will make training extremely slow.

Shouldn't the teacher model be in inference mode?

TouchSkyWf · Jun 22 '22 02:06

Hi @TouchSkyWf, thank you for your interest in our research.

I set the teacher network to training mode, but that does not mean the teacher network is trained during distillation. The teacher's features are detached in the loss computation, so backpropagation does not compute gradients for the teacher network, which means there is no slowdown or additional computation from the teacher being in training mode.

https://github.com/clovaai/overhaul-distillation/blob/76344a84a7ce23c894f41a2e05b866c9b73fd85a/CIFAR-100/distiller.py#L69-L70

I want to emphasize that training mode in PyTorch has no effect on gradient calculation; it only switches a network's modules between training and inference behavior.
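
To make the distinction concrete, here is a minimal sketch (stand-in `nn.Linear` modules, not the repository's actual networks) showing that a teacher in training mode still receives no gradients once its output is detached:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

s_net = nn.Linear(8, 8)   # stand-in for the student network
t_net = nn.Linear(8, 8)   # stand-in for the teacher network

s_net.train()
t_net.train()             # training *mode* only changes module behavior (BN, Dropout)

x = torch.randn(4, 8)
feat_s = s_net(x)
feat_t = t_net(x).detach()          # cut the graph: no gradients will reach t_net

loss = F.mse_loss(feat_s, feat_t)
loss.backward()

print(s_net.weight.grad is not None)  # True  -> student receives gradients
print(t_net.weight.grad is None)      # True  -> teacher receives no gradients
```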

Training mode is set because of the BatchNorm layers in the teacher network. Since BatchNorm uses running_mean and running_var in inference mode, we thought the mode of BatchNorm might affect distillation performance. In Table 8 of the paper, we observe that BatchNorm in training mode performs better than in inference mode, so we keep the teacher network in training mode.
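
As a small illustration of that mode difference (an assumed toy example, not code from this repository): in training mode BatchNorm normalizes with the current batch's statistics, while in inference mode it uses the stored running_mean and running_var.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
x = torch.randn(16, 4) * 5 + 3   # data whose statistics differ from the BN defaults

bn.train()
out_train = bn(x)                # normalized with this batch's mean/var (running stats updated)
bn.eval()
out_eval = bn(x)                 # normalized with the stored running_mean / running_var

print(out_train.mean(dim=0))     # close to 0: batch statistics were used
print(out_eval.mean(dim=0))      # generally far from 0: running statistics were used
```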

I agree that other approaches, such as wrapping the teacher's forward pass in with torch.no_grad():, might improve the readability of our code. But I didn't know of a method other than detach() when we wrote the paper.
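
A rough sketch of that alternative (with placeholder module names, not the repository's code) could look like this:

```python
import torch
import torch.nn as nn

def forward_with_frozen_teacher(s_net, t_net, x):
    """Return student and teacher features; the teacher forward builds no graph."""
    feat_s = s_net(x)            # student: autograd graph is built as usual
    with torch.no_grad():        # teacher: no graph, no gradient bookkeeping
        feat_t = t_net(x)
    return feat_s, feat_t        # feat_t.requires_grad is False, as with detach()

# Tiny usage with stand-in modules (placeholders, not the repo's networks)
s_net, t_net = nn.Linear(8, 8), nn.Linear(8, 8)
feat_s, feat_t = forward_with_frozen_teacher(s_net, t_net, torch.randn(4, 8))
print(feat_t.requires_grad)      # False
```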

Best,
Byeongho Heo

bhheo · Jul 12 '22 09:07