vit-pytorch
Update distill.py to include device agnostic code for `distill_mlp` head and `distillation_token`
Since the `distillation_token` and the `distill_mlp` head are defined in the `DistillWrapper` class, sending a model instance of the `DistillableViT` class to the GPU does not move the `distillation_token` and `distill_mlp` head along with it. While training a model with this code, I got a device mismatch error whose source was hard to track down. The culprits turned out to be `distillation_token` and `distill_mlp`, because they are not defined in the model class but in the `DistillWrapper` class, which is a wrapper around the loss function. I therefore suggest the following change when training a model on GPU: the training code should set `device = "cuda" if torch.cuda.is_available() else "cpu"`, or the same device-agnostic logic can be incorporated into the constructor of the `DistillWrapper` class.
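Below is a minimal sketch of the training-side workaround, loosely following the distillation example in the README; the hyperparameters and the ResNet-50 teacher are placeholders, not part of the proposed change. Because `DistillWrapper` is itself an `nn.Module`, moving the wrapper (rather than only the `DistillableViT` student) to the selected device also moves the `distillation_token` and `distill_mlp` parameters, which avoids the device mismatch.

```python
import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper

# device-agnostic selection, as suggested above
device = "cuda" if torch.cuda.is_available() else "cpu"

teacher = resnet50()  # untrained placeholder teacher; use a pretrained model in practice

student = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = student,
    teacher = teacher,
    temperature = 3,
    alpha = 0.5,
    hard = False
)

# moving only `student` to the GPU leaves distillation_token and distill_mlp
# (owned by the wrapper) on the CPU; move the whole wrapper instead so that
# every parameter ends up on the same device
distiller = distiller.to(device)

img = torch.randn(2, 3, 256, 256, device = device)
labels = torch.randint(0, 1000, (2,), device = device)

loss = distiller(img, labels)
loss.backward()
```

Incorporating the same `torch.cuda.is_available()` check into the `DistillWrapper` constructor, as proposed, would make this explicit `.to(device)` call on the wrapper unnecessary.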