
Fix ORTTrainer failure on DeBERTa(base/v2/sew_d) fp16 training

Open · JingyaHuang opened this pull request · 1 comment

What does this PR do?

Context

It was reported in optimum (https://github.com/huggingface/optimum/issues/305) that training DeBERTa with optimum.onnxruntime.ORTTrainer is broken. After investigation, the failure has two causes:

  • At that time, XDropout did not have a symbolic function. One was implemented by @garymm in https://github.com/huggingface/transformers/pull/17502 and has been merged into the main branch of transformers (a sketch of this pattern follows the list).
  • The implementation of DeBERTa has some numpy/math operations that lead to an incorrect export. This will be fixed in https://github.com/huggingface/transformers/pull/18272.
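
For context, the first fix works by giving the custom autograd function a `symbolic` static method so that `torch.onnx.export` can lower it to a standard ONNX node instead of failing on untraceable Python autograd code. A minimal sketch of that mechanism (the class name and the dropout logic are illustrative, not the actual transformers implementation from #17502):

```python
# Sketch of the "symbolic static method" pattern that makes a custom
# autograd.Function exportable to ONNX. Illustrative only.
import torch


class SketchDropout(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, dropout_p):
        # Plain eager-mode dropout: build a mask and rescale the kept values.
        mask = (torch.rand_like(input) > dropout_p).to(input.dtype)
        ctx.scale = 1.0 / (1.0 - dropout_p)
        ctx.save_for_backward(mask)
        return input * mask * ctx.scale

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        return grad_output * mask * ctx.scale, None

    @staticmethod
    def symbolic(g, input, dropout_p):
        # Tell torch.onnx.export how to lower this op: emit a standard ONNX
        # Dropout node rather than tracing through the Python autograd code.
        ratio = g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float))
        training = g.op("Constant", value_t=torch.tensor(True))
        output, _ = g.op("Dropout", input, ratio, training, outputs=2)
        return output
```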

However, with those two fixes fp32 training works, but mixed-precision training still fails due to mismatched input dtypes on some MatMul nodes. In https://github.com/huggingface/transformers/pull/18272, some sqrt results are cast to fp32, and they need to be cast back to fp16 before the MatMul ops; this PR adds that re-cast.
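
Concretely, the re-cast amounts to converting the fp32 scaling factor produced by the sqrt back to the hidden-state dtype before it participates in a matrix multiplication, so both MatMul inputs in the exported graph share the same dtype. A minimal sketch of the pattern (function and variable names are illustrative, not the exact DeBERTa attention code):

```python
import torch


def scaled_attention_scores(query_layer, key_layer, scale_factor=1.0):
    # torch.sqrt of a float tensor stays fp32 even when the model runs in fp16,
    # so the scale must be re-cast before it mixes with fp16 activations.
    scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
    # Re-cast to the query dtype (fp16 under mixed precision) so both MatMul
    # inputs in the exported ONNX graph have matching dtypes.
    query_layer = query_layer / scale.to(dtype=query_layer.dtype)
    return torch.matmul(query_layer, key_layer.transpose(-1, -2))
```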

Fixes https://github.com/huggingface/optimum/issues/305

Who can review?

@LysandreJik @patrickvonplaten @lewtun

JingyaHuang · Aug 08 '22


Closing, as it turned out to be too messy even after rebasing.

JingyaHuang · Aug 11 '22