Data validation when using differentiable_head
Hi,
Thanks for the great package. I was training a model with the default settings, which went great. I then switched to a differentiable head, which crashed because I used string labels and didn't specify the number of classes. My bad! I should've read the docs.
However, I think it would be nice if it crashed immediately, and not after the fine-tuning stage. I think the dataset parameters could be checked, or even converted automatically:
- We could check the number of classes and set the output layer dimensionality automatically (unless the differentiable head is created before the call to `.train()`).
- We could convert the labels to integers automatically, which is exactly what a scikit-learn `LogisticRegression` ends up doing under the hood.
- In the worst case, it would be nice to just throw a `ValueError` before starting training if the labels/classes aren't correct (i.e., the number of classes doesn't match, or the labels are strings).
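To illustrate, here is a minimal sketch of what such a pre-training check could look like. The function name and the `num_classes` parameter are hypothetical, not part of the SetFit API; the idea is just to validate and encode labels up front, before any fine-tuning happens:

```python
def validate_and_encode_labels(labels, num_classes=None):
    """Map arbitrary (e.g. string) labels to integer ids and check the class count.

    Hypothetical helper sketching the proposed validation; `num_classes`
    stands in for the differentiable head's output dimensionality.
    """
    classes = sorted(set(labels))
    if num_classes is not None and len(classes) != num_classes:
        # Fail fast, before training, instead of crashing mid-fit.
        raise ValueError(
            f"Found {len(classes)} classes in the labels, "
            f"but the head expects {num_classes}."
        )
    mapping = {label: idx for idx, label in enumerate(classes)}
    return [mapping[label] for label in labels], mapping


encoded, mapping = validate_and_encode_labels(["neg", "pos", "neg"], num_classes=2)
# encoded -> [0, 1, 0], mapping -> {"neg": 0, "pos": 1}
```

When `num_classes` is omitted, the same helper could simply infer it from the data and size the output layer accordingly.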
I'd definitely be willing to do a PR, but only if you think it makes sense.