
Update distill.py to include device-agnostic code for the `distill_mlp` head and `distillation_token`

Open · vivekh2000 opened this issue 7 months ago · 0 comments

In your code, the `distillation_token` and the `distill_mlp` head are defined in the `DistillWrapper` class rather than in `DistillableViT`. As a result, sending a `DistillableViT` instance to the GPU does not move the `distillation_token` or the `distill_mlp` head along with it. While training a model with this code, I got a device mismatch error whose source was hard to track down; the `distillation_token` and `distill_mlp` finally turned out to be the culprits, because they live not in the model class but in `DistillWrapper`, which acts as a wrapper around the loss function. I therefore suggest the following change for GPU training: the training code should set `device = "cuda" if torch.cuda.is_available() else "cpu"` and move the wrapper accordingly, or the same logic could be incorporated into the constructor of the `DistillWrapper` class.
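For reference, here is a minimal sketch of the workaround, assuming the `DistillableViT` / `DistillWrapper` usage shown in the repository README (the teacher model, hyperparameters, and tensor shapes below are illustrative):

```python
import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper

device = "cuda" if torch.cuda.is_available() else "cpu"

teacher = resnet50(pretrained = True)

student = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = student,
    teacher = teacher,
    temperature = 3,
    alpha = 0.5,
    hard = False
)

# Moving only the student leaves the wrapper-owned parameters behind:
# student.to(device)  # distillation_token / distill_mlp would stay on the CPU

# Moving the whole wrapper relocates the student, the teacher, and the
# wrapper's own distillation_token and distill_mlp in a single call:
distiller = distiller.to(device)

img = torch.randn(2, 3, 256, 256, device = device)
labels = torch.randint(0, 1000, (2,), device = device)

loss = distiller(img, labels)  # no device mismatch now
loss.backward()
```

Since `DistillWrapper` is itself an `nn.Module`, calling `.to(device)` on it recursively moves every registered parameter and submodule, which is why the single call above avoids the mismatch.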

vivekh2000 · Jul 25 '24 16:07