Fix ORTTrainer failure on DeBERTa(base/v2/sew_d) fp16 training
What does this PR do?
Context
It was reported in optimum (https://github.com/huggingface/optimum/issues/305) that training DeBERTa with optimum.onnxruntime.ORTTrainer is broken. After investigation, the breakage comes from two causes:
- At that time, `XDropout` didn't have a symbolic function. One has since been implemented by @garymm in https://github.com/huggingface/transformers/pull/17502 and merged into the main branch of transformers (a rough sketch of what such a symbolic function looks like follows this list).
- The implementation of DeBERTa has some numpy/math operations that led to an incorrect export. This will be fixed in https://github.com/huggingface/transformers/pull/18272.
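For context only (this is not this PR's change), here is a minimal sketch of what registering an ONNX symbolic function on a custom autograd `Function` looks like. The simplified `XDropout` below is illustrative and not the actual implementation from https://github.com/huggingface/transformers/pull/17502:

```python
import torch
from torch.autograd import Function


class XDropout(Function):
    """Simplified stand-in for DeBERTa's XDropout (illustrative only)."""

    @staticmethod
    def forward(ctx, input, dropout_p):
        mask = (torch.rand_like(input) > dropout_p).to(input.dtype)
        ctx.save_for_backward(mask)
        ctx.scale = 1.0 / (1.0 - dropout_p)
        return input * mask * ctx.scale

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        return grad_output * mask * ctx.scale, None

    @staticmethod
    def symbolic(g, input, dropout_p):
        # Without a symbolic function, torch.onnx.export cannot trace through
        # the custom Function; here we map it onto the standard ONNX Dropout op.
        ratio = g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float))
        train = g.op("Constant", value_t=torch.tensor(True))
        # Dropout (opset 12+) returns (output, mask); only the output is needed.
        output, _ = g.op("Dropout", input, ratio, train, outputs=2)
        return output
```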
However, even with those two fixes, fp32 training works but mixed-precision training still fails due to mismatched input dtypes on some MatMul nodes. In https://github.com/huggingface/transformers/pull/18272, some sqrt results are cast to fp32, and they need to be cast back to fp16 before the MatMul ops; this PR adds that re-cast.
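A minimal sketch of the kind of re-cast this PR adds, using a hypothetical `scaled_attention_scores` helper rather than the actual DeBERTa attention code: the fp32 sqrt scale is cast back to the query dtype so that both MatMul inputs share the same dtype in the exported graph.

```python
import torch


def scaled_attention_scores(query_layer, key_layer, scale_factor=1):
    # Illustrative helper, not the actual DeBERTa code.
    # The sqrt is computed in fp32 so the scale itself stays numerically
    # stable under mixed precision.
    scale = torch.sqrt(
        torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor
    )
    # Re-cast the fp32 scale to the query dtype (fp16 in mixed-precision
    # training) before the MatMul, so the exported graph does not end up
    # with a MatMul whose inputs have mismatched dtypes.
    query_layer = query_layer / scale.to(query_layer.dtype)
    return torch.matmul(query_layer, key_layer.transpose(-1, -2))
```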
Fixes https://github.com/huggingface/optimum/issues/305
Who can review?
@LysandreJik @patrickvonplaten @lewtun
Closing as it turned out to be too messy even after rebasing.