
Data validation when using differentiable_head

Open stephantul opened this issue 1 year ago • 0 comments

Hi,

Thanks for the great package. I was training a model with the default settings, which went great. I then switched to a differentiable head, which crashed because I used string labels and didn't specify the number of classes. My bad! I should've read the docs.

However, I think it would be nice if this crashed immediately, rather than only after the embedding fine-tuning stage. The dataset labels could be checked up front, or even converted automatically (see the sketch after this list):

  • We could check the number of classes and hence set the output layer dimensionality automatically (unless the differentiable head is created before the call to .train()).
  • We could definitely convert the labels to integers automatically, which is exactly what scikit-learn's LogisticRegression ends up doing under the hood.
  • In the worst case, it would be nice to just throw a ValueError before training starts if the labels/classes aren't correct (i.e., the number of classes doesn't match or the labels are strings).
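
To make the idea concrete, here's a rough sketch of the kind of check/conversion I have in mind. The `validate_and_encode_labels` helper and the toy dataset are made up for illustration; the `SetFitModel.from_pretrained(..., use_differentiable_head=True, head_params=...)` call follows the documented way of sizing the differentiable head:

```python
from datasets import Dataset
from setfit import SetFitModel


def validate_and_encode_labels(dataset: Dataset, label_column: str = "label"):
    """Map string labels to 0..n-1 integers and return (dataset, num_classes)."""
    labels = dataset[label_column]
    if not all(isinstance(label, (str, int)) for label in labels):
        # Worst case: fail loudly before any fine-tuning has happened.
        raise ValueError(f"Column '{label_column}' must contain string or integer labels.")
    if any(isinstance(label, str) for label in labels):
        # Roughly what sklearn's LogisticRegression does internally via LabelEncoder.
        label_to_id = {label: idx for idx, label in enumerate(sorted(set(labels)))}
        dataset = dataset.map(
            lambda example: {label_column: label_to_id[example[label_column]]},
            remove_columns=[label_column],  # drop the old string column so the dtype changes cleanly
        )
    num_classes = len(set(dataset[label_column]))
    return dataset, num_classes


# Hypothetical usage: infer the head size before the model is built.
train_dataset = Dataset.from_dict(
    {"text": ["great movie", "terrible movie", "fine, I guess"],
     "label": ["positive", "negative", "neutral"]}
)
train_dataset, num_classes = validate_and_encode_labels(train_dataset)
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    use_differentiable_head=True,
    head_params={"out_features": num_classes},
)
```

Something along these lines could run inside the trainer before fitting starts, so string labels either get converted or rejected immediately.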

I'd definitely be willing to do a PR, but only if you think it makes sense.

stephantul · Apr 22 '24 07:04