
Implement end-to-end differentiable model

Open lewtun opened this issue 2 years ago • 12 comments

The current SetFit implementation combines the embeddings of a frozen body with the learnable parameters of a logistic regression classifier.

It would be desirable to have a PyTorch implementation that is end-to-end differentiable. This isn't entirely trivial as we'll likely need different learning rates for the body vs the head.
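As a rough illustration of what "end-to-end differentiable" means here, the following is a minimal sketch of a body-plus-head module where gradients flow through both parts. The class name `DifferentiableSetFit` and its constructor are hypothetical, not the actual SetFit API.

```python
# Hypothetical sketch: a sentence-transformer-style body producing
# embeddings, followed by a torch linear head, trainable end-to-end.
import torch
from torch import nn


class DifferentiableSetFit(nn.Module):
    def __init__(self, body: nn.Module, embed_dim: int, num_classes: int):
        super().__init__()
        self.body = body                  # e.g. a sentence-transformer encoder
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        embeddings = self.body(features)  # (batch, embed_dim)
        return self.head(embeddings)      # (batch, num_classes) logits
```

Because the head is a torch module rather than a scikit-learn estimator, a single backward pass updates both body and head parameters.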

lewtun avatar Jul 08 '22 10:07 lewtun

Hi @lewtun,

Just wanted to ask: is this issue in progress now? If not, I have some bandwidth to work on it. Or maybe we should wait until #69 is merged?

Thanks!

blakechi avatar Oct 07 '22 06:10 blakechi

Hey @blakechi if you have bandwidth, I would certainly welcome a PR for this 😍 !

For the API, I've been wondering whether we need something like fastai's freeze() and unfreeze() methods, e.g.:

trainer = SetFitTrainer(...)

# Freeze head
trainer.freeze()

# Do contrastive training
trainer.train(num_epochs=1)

# Unfreeze head
trainer.unfreeze()

# Train end-to-end
trainer.train(num_epochs=1)

The alternative would be to do the freeze/unfreeze automatically for the user and have a single trainer.train() step - I think the answer will depend a bit on how easy it is to make this work end-to-end.

Another thing to keep in mind is that we'll probably need different learning rates for the head vs body, so the fastai approach provides a potentially simpler way to decouple these steps by simply passing a different learning_rate to each trainer.train() call.
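One standard way to get different learning rates for the body and head is torch optimizer parameter groups; the sketch below is illustrative (the specific modules and learning-rate values are placeholders, not SetFit's defaults).

```python
# Illustrative only: one optimizer, two parameter groups with separate
# learning rates for the body and the head.
import torch
from torch import nn

body = nn.Linear(16, 16)  # stand-in for the transformer body
head = nn.Linear(16, 2)   # stand-in for the classification head

optimizer = torch.optim.AdamW([
    {"params": body.parameters(), "lr": 1e-5},  # small lr for the body
    {"params": head.parameters(), "lr": 1e-2},  # larger lr for the head
])
```

The fastai-style approach mentioned above sidesteps this by running two separate `trainer.train()` calls, each with its own learning rate.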

And yes, it's probably best to wait until #69 is merged - let's see what their timeline is :)

lewtun avatar Oct 07 '22 08:10 lewtun

Hey @lewtun, Great! The API looks good to me, but just a suggestion: since trainer.freeze kind of implies we're going to freeze all of the model (body + head), what do you think about branching out freeze/unfreeze for the head and body individually, with freeze/unfreeze covering both? E.g.:

# freeze/unfreeze body only
trainer.freeze_body()
trainer.unfreeze_body()

# freeze/unfreeze head only
trainer.freeze_head()
trainer.unfreeze_head()

# freeze/unfreeze both body and head
trainer.freeze()
trainer.unfreeze()
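Under the hood, methods like these usually just toggle `requires_grad` on the relevant parameters. A minimal sketch, assuming the body/head split suggested above (the `Model` class and method names here are illustrative, not the actual SetFit implementation):

```python
# Sketch: freeze/unfreeze the body by toggling requires_grad, so frozen
# parameters receive no gradient updates during training.
import torch
from torch import nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Linear(8, 8)  # stand-in for the transformer body
        self.head = nn.Linear(8, 2)  # stand-in for the head

    def freeze_body(self):
        for p in self.body.parameters():
            p.requires_grad = False

    def unfreeze_body(self):
        for p in self.body.parameters():
            p.requires_grad = True
```

`freeze_head`/`unfreeze_head` would do the same over `self.head.parameters()`.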

Okay, I will work on it after #69 is merged. ;)

Thanks!

blakechi avatar Oct 08 '22 06:10 blakechi

Sure, we can have methods like you suggest if we find it's necessary to freeze/unfreeze the head and the body.

I'm not 100% sure yet if we can get by with just freezing the body - it will probably take some experimentation to find out what works best :)

lewtun avatar Oct 10 '22 10:10 lewtun

Hey @blakechi we've just merged #69 so feel free to take a stab at the pure PyTorch model whenever you want (and feel free to ping me here if you have questions!)

lewtun avatar Oct 11 '22 13:10 lewtun

Hi @lewtun, good to hear that! I think I will begin with implementing a new class SetFitHead and then integrate it into SetFitModel by adding one more argument: use_differentiable_head. Does this sound good to you?
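A hedged sketch of that plan: `SetFitHead` as a simple torch module over the body's embeddings, and a `use_differentiable_head` flag selecting between it and the existing logistic regression head. The constructor signature and the `build_head` helper are assumptions for illustration, not the merged implementation.

```python
# Sketch: a differentiable head plus a flag-controlled factory that
# falls back to the existing scikit-learn logistic regression head.
import torch
from torch import nn
from sklearn.linear_model import LogisticRegression


class SetFitHead(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        return self.linear(embeddings)  # logits


def build_head(use_differentiable_head: bool, embed_dim: int, num_classes: int):
    # Hypothetical helper: pick the head based on the proposed flag.
    if use_differentiable_head:
        return SetFitHead(embed_dim, num_classes)
    return LogisticRegression()
```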

blakechi avatar Oct 12 '22 05:10 blakechi

> Hi @lewtun, good to hear that! I think I will begin with implementing a new class SetFitHead and then integrate it into SetFitModel by adding one more argument: use_differentiable_head. Does this sound good to you?

That sounds like a good plan to start with. If we can consistently match the performance of the scikit-learn approach, we can eventually deprecate that part and have a pure PyTorch model (my preference since it enables other features like ONNX export in future)

lewtun avatar Oct 12 '22 08:10 lewtun

Ok, I will start to implement it!

blakechi avatar Oct 13 '22 06:10 blakechi

Hi @lewtun,

I think I'm about halfway (or more) through the implementation, and a question came to mind. With the API you suggested below, after we unfreeze the head and start training again, the body will also be trained since it's end-to-end. Do you think we should also let users freeze the body and train only the head, as before?

trainer = SetFitTrainer(...)

# Freeze head
trainer.freeze()

# Do contrastive training
trainer.train(num_epochs=1)

# Unfreeze head
trainer.unfreeze()

# Train end-to-end
trainer.train(num_epochs=1)

blakechi avatar Oct 17 '22 07:10 blakechi

This is great news, and a good question @blakechi !

I think for the first version, enabling the body to be frozen makes sense - especially so we can check that we can reproduce the results from the original logistic regression head :)

Maybe this could be supported with a simple boolean arg like keep_body_frozen in trainer.unfreeze()?
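Sketched out, the suggested flag might look like this: `unfreeze()` always re-enables gradients for the head, and for the body only when `keep_body_frozen` is False. The `Trainer` class here is illustrative, not the actual SetFitTrainer code.

```python
# Sketch: unfreeze() with a keep_body_frozen flag, so the head can be
# trained on its own while the body stays frozen.
import torch
from torch import nn


class Trainer:
    def __init__(self, model_body: nn.Module, model_head: nn.Module):
        self.body, self.head = model_body, model_head

    def unfreeze(self, keep_body_frozen: bool = False):
        for p in self.head.parameters():
            p.requires_grad = True       # head always becomes trainable
        if not keep_body_frozen:
            for p in self.body.parameters():
                p.requires_grad = True   # body only when requested
```

Calling `trainer.unfreeze(keep_body_frozen=True)` would then reproduce the original logistic-regression-style setup: frozen body, trainable head.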

lewtun avatar Oct 17 '22 08:10 lewtun

Sounds great! Just added that. Starting to work on testing now.

blakechi avatar Oct 18 '22 06:10 blakechi

Hi @lewtun,

Just opened a pull request for this issue! Please take a look.

I haven't trained with the differentiable head yet; could you suggest a script for training? Thanks!

blakechi avatar Oct 19 '22 04:10 blakechi

Can this be closed since the PR was merged?

PhilipMay avatar Nov 01 '22 16:11 PhilipMay

Yes, I think we can!

lewtun avatar Nov 01 '22 16:11 lewtun