
Defining argument names to avoid issues with positional args

Open pedrogengo opened this issue 1 year ago • 11 comments

Fixes #418, #291, #355

Some models that don't use token_type_ids were producing different results when run through ONNX. The issue was caused by not explicitly naming the arguments in the call, which made the model interpret a tensor as a different argument.
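
For illustration, a minimal sketch of the failure mode (the checkpoint is just an example, not the exporter code itself): with a positional call, the tensor intended as token_type_ids binds to whatever parameter happens to sit third in the body's forward() signature, e.g. position_ids for MPNet-based models, which silently changes the computation.

import torch
from transformers import AutoModel, AutoTokenizer

# Illustrative checkpoint: MPNet has no token_type_ids; its forward() signature
# is (input_ids, attention_mask, position_ids, ...).
checkpoint = "sentence-transformers/paraphrase-mpnet-base-v2"
model = AutoModel.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
encoded = tokenizer("hello world", return_tensors="pt")
token_type_ids = torch.zeros_like(encoded["input_ids"])  # dummy tensor, as an exporter might build

with torch.no_grad():
    # Positional call: the third tensor lands in the position_ids slot.
    positional = model(encoded["input_ids"], encoded["attention_mask"], token_type_ids)
    # Keyword call: every tensor is bound to the intended parameter by name.
    keyword = model(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])

print(torch.allclose(positional.last_hidden_state, keyword.last_hidden_state))  # typically False

The keyword-argument form is what this PR switches to.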

@tomaarsen

pedrogengo avatar Nov 27 '23 13:11 pedrogengo

@tomaarsen should we include the softmax operation inside the ONNX model, or is returning the logits the desired behavior?

pedrogengo avatar Nov 27 '23 13:11 pedrogengo

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

I quite like the simplicity of this fix, and I do hope that this indeed resolves the problem. However, the tests seem to be problematic:

onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid Feed Input Name:token_type_ids

Perhaps ONNX doesn't want token_type_ids to be passed, even if it is None? Then a solution might be:

inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
}
# Only include token_type_ids when the tokenizer/model actually produced them
if token_type_ids is not None:
    inputs["token_type_ids"] = token_type_ids
hidden_states = self.model_body(**inputs)
  • Tom Aarsen

tomaarsen avatar Nov 27 '23 14:11 tomaarsen

Don't run the CI yet, I'm still running some tests.

pedrogengo avatar Nov 27 '23 16:11 pedrogengo

@tomaarsen the issue with the tests is a bit harder: for models that use token_type_ids, the tests run like a charm. However, for models that don't use token_type_ids, the ONNX export generates the graph without token_type_ids as an input. I tried your suggestion, but without success.

This is also an issue in transformers: https://discuss.huggingface.co/t/deberta-v2-onnx-with-pipeline-does-not-work/35748

What I'm trying to do now is force token_type_ids to appear in the graph; let's see whether it works.
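
An alternative that avoids touching the export at all (a hedged sketch, not what this PR ends up doing; the ONNX file name is a placeholder) would be to filter the feed dict at inference time down to the inputs the exported graph actually declares, which sidesteps the "Invalid Feed Input Name:token_type_ids" error:

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

# Build the feed dict from the inputs the exported graph actually declares, so a
# graph without token_type_ids is still fed correctly.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
session = ort.InferenceSession("setfit_model.onnx")  # placeholder path

encoded = tokenizer(["hello world"], return_tensors="np")
graph_inputs = {node.name for node in session.get_inputs()}
feed = {name: np.asarray(array) for name, array in encoded.items() if name in graph_inputs}
outputs = session.run(None, feed)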

Do you have any other suggestion?

pedrogengo avatar Nov 27 '23 17:11 pedrogengo

Hmm, that is tricky. I'm not extremely familiar with the inner workings of ONNX, so I don't have great suggestions I'm afraid.

  • Tom Aarsen

tomaarsen avatar Nov 28 '23 09:11 tomaarsen

I will try a few more things today, but for now I'm showing a message when the model doesn't use token_type_ids, just to validate that the whole flow works as expected.

pedrogengo avatar Nov 28 '23 10:11 pedrogengo

@tomaarsen I found a way to keep the same interface we are using today and force token_type_ids to appear. It can't be an optional argument: once the ONNX graph is defined, you must provide every input used during the export, which means we always need to pass token_type_ids even if the model doesn't use it.

On one hand it may look like: "Why keep an argument that some models don't use?", but the answer is generalization. Keeping the parameter in all cases makes it possible to have export code that is general enough for every model.
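
Roughly, the idea looks like this (a hedged sketch with a hypothetical wrapper and placeholder checkpoint, not the exporter code in this PR): the wrapper always accepts token_type_ids, and for architectures that ignore it, touches it in a no-op so tracing does not prune it from the exported graph.

import inspect
import torch
from transformers import AutoModel, AutoTokenizer

class ExportWrapper(torch.nn.Module):
    # Hypothetical sketch of the idea, not the PR's actual code.
    def __init__(self, body):
        super().__init__()
        self.body = body
        self.uses_token_type_ids = "token_type_ids" in inspect.signature(body.forward).parameters

    def forward(self, input_ids, attention_mask, token_type_ids):
        if self.uses_token_type_ids:
            out = self.body(input_ids=input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        else:
            # No-op use of token_type_ids so tracing keeps it as a graph input.
            out = self.body(input_ids=input_ids + token_type_ids * 0,
                            attention_mask=attention_mask)
        return out.last_hidden_state

checkpoint = "sentence-transformers/paraphrase-mpnet-base-v2"
body = AutoModel.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
encoded = tokenizer("hello world", return_tensors="pt")
dummy = (encoded["input_ids"], encoded["attention_mask"], torch.zeros_like(encoded["input_ids"]))

torch.onnx.export(
    ExportWrapper(body), dummy, "body.onnx",
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    output_names=["last_hidden_state"],
    dynamic_axes={name: {0: "batch", 1: "sequence"}
                  for name in ["input_ids", "attention_mask", "token_type_ids", "last_hidden_state"]},
)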

Let me know WDYT about the solution. Maybe in the future we can work to create better interfaces for the export, but I tried to keep this PR as simple as possible.

pedrogengo avatar Nov 28 '23 12:11 pedrogengo

@tomaarsen can we merge this? If so, I will resolve the conflict.

pedrogengo avatar Dec 06 '23 16:12 pedrogengo

Apologies for not responding sooner, I've been a bit busy and ONNX wasn't very high on my TODO list. Do you suspect that this PR indeed fixes the reported discrepancies? E.g. does the script from #291 behave as expected now? The fix seems so odd to me, haha.

Also, models with a differentiable head are trained a bit differently in SetFit v1.0.0, i.e. no more freeze and unfreeze calls, and only calling trainer.train() once.
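For reference, the v1.0.0 flow with a differentiable head looks roughly like this (a sketch; the dataset, column mapping, and hyperparameters are placeholders):

from datasets import load_dataset
from setfit import SetFitModel, Trainer, TrainingArguments

# Small placeholder dataset; the point is the single train() call.
train_dataset = load_dataset("sst2", split="train[:16]")

model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    use_differentiable_head=True,
    head_params={"out_features": 2},
)
args = TrainingArguments(batch_size=16, num_epochs=1)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    column_mapping={"sentence": "text", "label": "label"},
)
trainer.train()  # no freeze()/unfreeze() calls needed in v1.0.0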

  • Tom Aarsen

tomaarsen avatar Dec 07 '23 14:12 tomaarsen

I will run the script with this branch, but in my own tests I was also seeing a discrepancy between the results, and after the fix they matched and returned the same scores.

Give me the weekend to look over the code for v1.0.0 and I'll answer here.

pedrogengo avatar Dec 07 '23 15:12 pedrogengo