
Inconsistent ONNX Export with Differentiable Head

Open · radandreicristian opened this issue 1 year ago · 5 comments

Hello,

I am trying to fine-tune SetFit for a multi-class classification problem.

Everything works smoothly until I export to ONNX. The head is not exported correctly, so when the model is loaded with ONNX, the predictions are the outputs of some earlier layer.

As an example, I reproduced this on a dummy dataset with 3 classes.

See https://colab.research.google.com/drive/19EESqbIDwD5FOI2Ufx22txFZ8qADq2xj?authuser=1#scrollTo=vl6AvQKqtU-9

This differs from the behaviour in the example notebook, where the LogisticRegression head is used.

Any pointers would be appreciated; otherwise, I would happily contribute a PR if anyone can spot the issue.

radandreicristian · Sep 06 '23 14:09

The output is the raw logits per class for each sample, and all you need to do then is a softmax/argmax to get the class labels.
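As a rough illustration of that post-processing step, here is a minimal NumPy sketch. It assumes the ONNX output is an array of raw logits with shape (n_samples, n_classes); the logit values below are made up for illustration and do not come from the Colab.

```python
import numpy as np

def logits_to_labels(logits):
    # Subtract each row's max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted)
    # Normalize each row so the class probabilities sum to 1 (softmax).
    probs /= probs.sum(axis=1, keepdims=True)
    # The predicted label is the index of the highest-scoring class.
    return probs, probs.argmax(axis=1)

# Hypothetical logits for 2 samples and 3 classes.
logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
probs, labels = logits_to_labels(logits)
print(labels)  # [0 2]
```

Because softmax is monotonic, `logits.argmax(axis=1)` alone yields the same labels; the softmax is only needed if you also want probabilities.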

Is this intended?

radandreicristian · Sep 07 '23 06:09

Hello!

I'm aware that there are some issues with exporting to ONNX, sadly. Hopefully I will have some time in the future to refactor the exporting into a more consistent approach. Thank you for raising this issue and for providing a Google Colab! It will certainly be helpful.

  • Tom Aarsen

tomaarsen · Nov 24 '23 13:11

Hey @tomaarsen, can I work on this?

pedrogengo · Nov 27 '23 09:11

After a long debugging session, I found the issue. Because we do not pass the arguments by name at https://github.com/huggingface/setfit/blame/cbc01ec402e86ca04e5e40e9bce7f618f3c2946c/src/setfit/exporters/onnx.py#L50

the transformers library treats the third positional argument as position_ids, not token_type_ids (https://github.com/huggingface/transformers/blob/b09912c8f452ac485933ac0f86937aa01de3c398/src/transformers/models/mpnet/modeling_mpnet.py#L515-L525). The fix is simply to pass the arguments by name.
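To make the mechanism concrete, here is a minimal standard-library sketch of how a positional call misbinds. `mpnet_forward` is a hypothetical stand-in whose leading parameters follow the order of `MPNetModel.forward` in the linked transformers source (input_ids, attention_mask, position_ids, head_mask). It does not run the real model or the ONNX exporter.

```python
import inspect

# Hypothetical stand-in mirroring the leading parameter order of
# MPNetModel.forward; note there is no token_type_ids parameter.
def mpnet_forward(input_ids=None, attention_mask=None, position_ids=None, head_mask=None):
    return input_ids, attention_mask, position_ids, head_mask

input_ids = [[0, 42, 2]]
attention_mask = [[1, 1, 1]]
token_type_ids = [[0, 0, 0]]

sig = inspect.signature(mpnet_forward)

# Positional call: the third tensor silently lands in position_ids.
positional = sig.bind(input_ids, attention_mask, token_type_ids)
print(dict(positional.arguments))

# Keyword call: each tensor is bound to the parameter it is named for,
# and position_ids is left unset, so the model would fall back to its default.
keyword = sig.bind(input_ids=input_ids, attention_mask=attention_mask)
print(dict(keyword.arguments))
```

The positional call raises no error, which is why the export "works" but produces the wrong outputs; naming the arguments makes the binding explicit.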

I will open a PR to fix this.

pedrogengo · Nov 27 '23 12:11

Here is the Colab notebook I used for debugging:

https://colab.research.google.com/drive/19xE4WdxqGLLZOSanycYfUzbcAxFgpzuR?usp=sharing

pedrogengo · Nov 27 '23 12:11