
fastai compatibility

Open thenatlog opened this issue 4 years ago • 9 comments

would it be possible to make the distill vit compatible with fastai? both the vanilla vit and efficient vit work fine.

thenatlog avatar Dec 27 '20 06:12 thenatlog

@lwomalley Hi Logan! I know of FastAI but am not too familiar with their API

What would it take to be compatible?

lucidrains avatar Dec 28 '20 01:12 lucidrains

In the vanilla vit, the forward pass takes in an image and returns its transformed output. In the distilled vit, the distillwrapper forward pass takes an image + labels and returns the loss. Fastai gets tripped up on this because it automatically supplies the model with labels and is expecting the forward pass to return a transformed output. If there is a way to decouple the forward pass and the loss calculation I suspect it may work.
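The mismatch can be sketched with toy stand-ins (plain Python, not the real `vit_pytorch` classes; all numbers are placeholders):

```python
# Toy stand-ins for the two interfaces described above (not the real classes).

class VanillaViT:
    """Vanilla ViT style: forward(img) -> transformed output (logits)."""
    def forward(self, img):
        return [0.1, 0.9]  # placeholder logits

class DistillWrapper:
    """Distilled ViT style: forward(img, labels) -> scalar loss."""
    def forward(self, img, labels):
        hard_loss = 0.5     # placeholder for the cross-entropy vs. labels
        distill_loss = 0.2  # placeholder for the teacher-distillation term
        return hard_loss + distill_loss

# fastai's loop does `pred = model(x)` then `loss = loss_func(pred, y)`, so a
# forward pass that already consumes labels and returns a loss breaks that flow.
```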

thenatlog avatar Dec 28 '20 17:12 thenatlog

@lwomalley yup, but the problem is the distillation comes with an auxiliary loss that gets returned. Will FastAI know to add this to the main loss it calculates?

lucidrains avatar Dec 28 '20 17:12 lucidrains

Don't know if I can help, but are "vit" and "distilled vit" layers? What is the auxiliary loss, and how would it be used in a "normal" training loop?

This is where the loss is calculated, but AFAIK we have the option of callbacks, and we can also "replace" this method on the learner itself with a new implementation/different code.

https://github.com/fastai/fastai/blob/master/fastai/learner.py#L172-L173

tyoc213 avatar Dec 28 '20 19:12 tyoc213

@tyoc213 yeah, it won't work, because that line needs to also add the distillation loss https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/distill.py#L120. I could return the logits as the first element of a tuple and the auxiliary loss as the second, but FastAI would still need to sum the auxiliary loss into the one it calculates.
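Returning the logits and the auxiliary loss as a tuple might look like this (a hypothetical refactor, not the actual `distill.py` code; values are placeholders):

```python
class DecoupledDistillWrapper:
    """Hypothetical forward that defers loss summation to the caller."""
    def forward(self, img):
        logits = [0.1, 0.9]  # placeholder student predictions
        aux_loss = 0.2       # placeholder distillation term
        return logits, aux_loss  # caller must add aux_loss into its main loss
```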

lucidrains avatar Dec 28 '20 19:12 lucidrains

Mmmm, so the forward pass is this: https://github.com/fastai/fastai/blob/master/fastai/learner.py#L169. In fastai it returns a tuple, IIRC, and we can apply whatever we want with a transform at https://github.com/fastai/fastai/blob/master/fastai/learner.py#L174, which calls the transforms that respond to the after_loss event. All the transforms have access to the learner and other things; see https://docs.fast.ai/callback.core.html#Callback

  • after_loss called after the loss has been computed, but before the backward pass. It can be used to add any penalty to the loss (AR or TAR in RNN training for instance).
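A callback along those lines might be sketched like this (plain Python, no fastai import; `after_loss`, `learn.loss`, and `model.aux_loss` are names assumed from fastai's Callback API and this thread, not tested against fastai itself):

```python
class DistillLossCallback:
    """Sketch: add a stashed auxiliary loss to the main loss in after_loss."""
    def __init__(self, learn):
        self.learn = learn  # real fastai callbacks get self.learn injected

    def after_loss(self):
        # fetch the aux loss the model stashed during its forward pass
        # and fold it into the loss fastai already computed
        self.learn.loss = self.learn.loss + self.learn.model.aux_loss
```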

tyoc213 avatar Dec 28 '20 23:12 tyoc213

@tyoc213 I see, so I'd have to store the auxiliary loss on the instance somewhere, and then in the callback it would be fetched and added to the main loss?
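That stash-and-fetch pattern on the model side might look like this (a hypothetical sketch; `aux_loss` is an assumed attribute name and the values are placeholders):

```python
class DistillWrapperForFastai:
    """Sketch: return only logits and stash the auxiliary loss on self."""
    def forward(self, img):
        logits = [0.1, 0.9]  # placeholder predictions
        self.aux_loss = 0.2  # stashed for an after_loss callback to fetch
        return logits
```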

lucidrains avatar Dec 29 '20 00:12 lucidrains

@tyoc213 do you have any code examples of how you are using ViT with FastAI?

lucidrains avatar Dec 29 '20 00:12 lucidrains

No, but maybe this tiny bit can help: https://youtu.be/4w3sEgqDvSo?t=1148

tyoc213 avatar Dec 29 '20 02:12 tyoc213