setfit
Implement end-to-end differentiable model
The current SetFit implementation combines the embeddings of a frozen body with the learnable parameters of a logistic regression classifier.
It would be desirable to have a PyTorch implementation that is end-to-end differentiable. This isn't entirely trivial as we'll likely need different learning rates for the body vs the head.
Hi @lewtun,
Just wanted to ask: is this issue in progress? If not, I have some bandwidth to work on it. Or maybe we should wait until #69 is merged?
Thanks!
Hey @blakechi if you have bandwidth, I would certainly welcome a PR for this 😍 !
For the API, I've been wondering whether we need something like fastai's `freeze()` and `unfreeze()` methods, e.g.:
```python
trainer = SetFitTrainer(...)
# Freeze head
trainer.freeze()
# Do contrastive training
trainer.train(num_epochs=1)
# Unfreeze head
trainer.unfreeze()
# Train end-to-end
trainer.train(num_epochs=1)
```
The alternative would be to do the freeze/unfreeze automatically for the user and have a single `trainer.train()` step. I think the answer will depend a bit on how easy it is to make this work end-to-end.
Another thing to keep in mind is that we'll probably need different learning rates for the head vs the body, so the fastai approach provides a potentially simpler way to decouple these steps: simply pass a different `learning_rate` to each `trainer.train()` call.
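For reference, plain PyTorch can also decouple the two learning rates within a single end-to-end step via optimizer parameter groups. A minimal sketch (the `body`/`head` modules and the learning-rate values here are illustrative assumptions, not the actual SetFit internals):

```python
import torch
from torch import nn

# Hypothetical two-part model: a "body" encoder and a "head" classifier.
body = nn.Linear(16, 8)
head = nn.Linear(8, 2)

# One optimizer with separate parameter groups lets the body and head
# use different learning rates in a single end-to-end training step.
optimizer = torch.optim.AdamW(
    [
        {"params": body.parameters(), "lr": 1e-5},  # small LR for the pretrained body
        {"params": head.parameters(), "lr": 1e-2},  # larger LR for the fresh head
    ]
)

x = torch.randn(4, 16)
loss = head(body(x)).sum()
loss.backward()
optimizer.step()
```

This is an alternative to two separate `train()` calls: both parts update together, each at its own rate.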
And yes, it's probably best to wait until #69 is merged - let's see what their timeline is :)
Hey @lewtun,
Great! And the API looks good to me. Just one suggestion: since `trainer.freeze()` sounds like it freezes all of the model (body + head), what do you think about branching out `freeze`/`unfreeze` for the head and body individually, with `freeze`/`unfreeze` applying to both? E.g.:
```python
# freeze/unfreeze body only
trainer.freeze_body()
trainer.unfreeze_body()
# freeze/unfreeze head only
trainer.freeze_head()
trainer.unfreeze_head()
# freeze/unfreeze both body and head
trainer.freeze()
trainer.unfreeze()
```
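In PyTorch, these methods would likely come down to toggling `requires_grad` on the relevant parameters. A sketch of how the proposed API could be backed (the `TwoPartModel` class and its `body`/`head` modules are stand-ins, not the real SetFit model):

```python
import torch
from torch import nn

class TwoPartModel(nn.Module):
    # Hypothetical stand-in for a SetFit-style model: body encoder + head.
    def __init__(self):
        super().__init__()
        self.body = nn.Linear(16, 8)
        self.head = nn.Linear(8, 2)

    def _set_requires_grad(self, module: nn.Module, value: bool) -> None:
        # Toggling requires_grad excludes/includes the parameters
        # from gradient computation, i.e. freezes/unfreezes them.
        for p in module.parameters():
            p.requires_grad = value

    def freeze_body(self):
        self._set_requires_grad(self.body, False)

    def unfreeze_body(self):
        self._set_requires_grad(self.body, True)

    def freeze_head(self):
        self._set_requires_grad(self.head, False)

    def unfreeze_head(self):
        self._set_requires_grad(self.head, True)

    def freeze(self):
        self.freeze_body()
        self.freeze_head()

    def unfreeze(self):
        self.unfreeze_body()
        self.unfreeze_head()

model = TwoPartModel()
model.freeze_body()  # only the head remains trainable
```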
Okay, I will work on it after #69 is merged. ;)
Thanks!
Sure, we can have methods like you suggest if we find it's necessary to freeze/unfreeze the head and the body.
I'm not 100% sure yet if we can get by with just freezing the body - it will probably take some experimentation to find out what works best :)
Hey @blakechi we've just merged #69 so feel free to take a stab at the pure PyTorch model whenever you want (and feel free to ping me here if you have questions!)
Hi @lewtun, good to hear that! I think I will begin by implementing a new class `SetFitHead` and then integrate it into `SetFitModel` by adding one more argument: `use_differentiable_head`. Does this sound good to you?
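For concreteness, the simplest differentiable replacement for a logistic regression head is a single linear layer over the body's sentence embeddings. A minimal sketch (the class name follows the proposal above, but the constructor signature and dimensions are assumptions):

```python
import torch
from torch import nn

class SetFitHead(nn.Module):
    """Sketch of a differentiable classification head: a single linear
    layer mapping sentence embeddings to class logits (equivalent in
    form to logistic regression, but trainable end-to-end)."""

    def __init__(self, embedding_dim: int, num_classes: int):
        super().__init__()
        self.linear = nn.Linear(embedding_dim, num_classes)

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        # embeddings: (batch_size, embedding_dim) from the frozen/unfrozen body
        return self.linear(embeddings)

head = SetFitHead(embedding_dim=768, num_classes=3)
logits = head(torch.randn(4, 768))
```

Trained with `nn.CrossEntropyLoss` on these logits, this should match the scikit-learn head's functional form while keeping gradients flowing into the body.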
That sounds like a good plan to start with. If we can consistently match the performance of the scikit-learn approach, we can eventually deprecate that part and have a pure PyTorch model (my preference since it enables other features like ONNX export in future)
Ok, I will start to implement it!
Hi @lewtun,
I think I'm about halfway (or more) through the implementation, and a question has come up. With the API you suggested below, after we unfreeze the head and fire up training again, the body will be trained as well, since it's end-to-end. Do you think we should also let users freeze the body and train only the head, as before?
```python
trainer = SetFitTrainer(...)
# Freeze head
trainer.freeze()
# Do contrastive training
trainer.train(num_epochs=1)
# Unfreeze head
trainer.unfreeze()
# Train end-to-end
trainer.train(num_epochs=1)
```
This is great news, and a good question @blakechi!
I think for the first version, enabling the body to be frozen makes sense - especially so we can check that we can reproduce the results from the original logistic regression head :)
Maybe this could be supported with a simple boolean arg like `keep_body_frozen` in `trainer.unfreeze()`?
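One way such a flag could behave, sketched in plain PyTorch (the `TinyModel` stand-in and the free-standing `unfreeze` helper are illustrative assumptions, not the actual trainer code):

```python
import torch
from torch import nn

class TinyModel(nn.Module):
    # Hypothetical stand-in for a SetFit-style model with a body and a head.
    def __init__(self):
        super().__init__()
        self.body = nn.Linear(16, 8)
        self.head = nn.Linear(8, 2)

def unfreeze(model: nn.Module, keep_body_frozen: bool = False) -> None:
    # Always unfreeze the head; optionally keep the body frozen so that
    # only the head is trained (mirrors the suggested keep_body_frozen flag).
    for p in model.head.parameters():
        p.requires_grad = True
    if not keep_body_frozen:
        for p in model.body.parameters():
            p.requires_grad = True

model = TinyModel()
# Simulate a fully frozen model, then unfreeze only the head.
for p in model.parameters():
    p.requires_grad = False
unfreeze(model, keep_body_frozen=True)
```

With `keep_body_frozen=True`, only the head's parameters receive gradients, reproducing the "train the head as usual" setup for comparison against the logistic regression baseline.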
Sounds great! Just added that. Starting to work on testing now.
Hi @lewtun,
Just opened a pull request for this issue! Please take a look.
I haven't trained with the differentiable head yet; could you suggest a script for training? Thanks!
Can this be closed since the PR was merged?
Yes, I think we can!