Defining argument names to avoid issues with positional args
Fixes #418, #291, #355
Some models that don't use token_type_ids were producing different results when run with ONNX. The issue was due to not explicitly naming the arguments during the call, which makes the model assume an argument was a different one.
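To illustrate the idea, here is a simplified sketch of the difference (model_body and the tensor names are placeholders, not the exact setfit internals):

# Positional call: tensors are bound by position, so if the body's signature
# is (input_ids, token_type_ids, attention_mask), the attention mask ends up
# in the token_type_ids slot and the trace is built on the wrong inputs.
hidden_states = model_body(input_ids, attention_mask, token_type_ids)

# Keyword call: each tensor is bound to the parameter it was meant for,
# regardless of the order in the body's signature.
hidden_states = model_body(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
)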
@tomaarsen should we include the softmax operation inside the ONNX model? Or is returning the logits the desired behavior?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
I quite like the simplicity of this fix, and I do hope that this indeed resolves the problem. However, the tests seem to be problematic:
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid Feed Input Name:token_type_ids
Perhaps ONNX doesn't want token_type_ids to be passed, even if it is None? Then a solution might be:
inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
}
if token_type_ids is not None:
    inputs["token_type_ids"] = token_type_ids
hidden_states = self.model_body(**inputs)
- Tom Aarsen
Don't run the CI yet, I'm still running some tests
@tomaarsen the issue with the tests is a bit harder: for models that use token_type_ids, the tests run like a charm. However, for models that don't use token_type_ids, the ONNX export generates the graph without token_type_ids as an input. I tried your suggestion, but no success.
It is also an issue in transformers: https://discuss.huggingface.co/t/deberta-v2-onnx-with-pipeline-does-not-work/35748
What I'm trying to do is to force token_type_ids to appear in the graph; let's see if it works. Do you have any other suggestions?
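For reference, one way to confirm which inputs actually made it into the exported graph (a hedged sketch; "model.onnx" is a placeholder path):

import onnxruntime as ort

# Load the exported model and list the inputs its graph expects.
session = ort.InferenceSession("model.onnx")
print([inp.name for inp in session.get_inputs()])
# For a body that ignores token_type_ids this prints something like
# ['input_ids', 'attention_mask'], i.e. the unused input was pruned from the
# graph, and feeding it later raises the "Invalid Feed Input Name" error above.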
Hmm, that is tricky. I'm not extremely familiar with the inner workings of ONNX, so I don't have great suggestions I'm afraid.
- Tom Aarsen
I will try more things today, but for now I'm showing a message when the model doesn't use token_type_ids, just to validate that the whole flow is working as expected.
@tomaarsen I found a way to keep the same interface we are using today and force token_type_ids to appear. It can't be an optional argument: once the ONNX graph is defined, you must provide all the inputs used during the export, which means we always need to pass token_type_ids even if it is not used by the model.
On one hand it may look like "Why keep an argument that is not used by some models?", but the answer is generalization. Keeping the parameter for all cases makes it possible to have export code that is general enough for all models.
Let me know WDYT about the solution. Maybe in the future we can work on better interfaces for the export, but I tried to keep this PR as simple as possible.
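Roughly, the consequence at inference time is that the session is always fed all three inputs. A minimal sketch of that idea (the function name and the zero-filled fallback are illustrative, not the exact code in this PR):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("setfit_model.onnx")  # placeholder path

def run_model(input_ids, attention_mask, token_type_ids=None):
    # The exported graph declares all three inputs, so every call must feed
    # them; fall back to zeros when the tokenizer did not produce any.
    if token_type_ids is None:
        token_type_ids = np.zeros_like(input_ids)
    return session.run(
        None,
        {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        },
    )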
@tomaarsen can we merge this? If yes, I will resolve the conflict
Apologies for not responding sooner, I've been a bit busy and ONNX wasn't very high on my TODO list. Do you suspect that this PR indeed fixes the reported discrepancies? E.g. does the script from #291 behave as expected now? The fix seems so odd to me, haha.
Also, models with a differentiable head are trained a bit differently in SetFit v1.0.0, i.e. no more freeze and unfreeze calls, and only calling trainer.train() once.
- Tom Aarsen
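For context, a minimal sketch of what differentiable-head training looks like on the v1.0.0 API (the dataset and hyperparameters are illustrative; this is not code from this PR):

from datasets import Dataset
from setfit import SetFitModel, Trainer, TrainingArguments

# Tiny illustrative dataset.
train_dataset = Dataset.from_dict(
    {"text": ["great movie", "terrible movie"], "label": [1, 0]}
)

model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-MiniLM-L3-v2",
    use_differentiable_head=True,
    head_params={"out_features": 2},
)

args = TrainingArguments(batch_size=16, num_epochs=1)

# v1.0.0: no freeze()/unfreeze() calls, a single train() covers both phases.
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()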
I will run the script with this branch, but in my tests I was also seeing a discrepancy between the results, and after the fix it worked and returned the same scores.
Give me the weekend to look at the code for v1.0.0 and I can answer here.