Fix CPU device parameter

Open mausch opened this issue 1 year ago • 2 comments

What does this PR do?

Fixes a crash in the Trainer when running on CPU.

This parameter is expected to be of type torch.device, not str.

Otherwise this crashes with:

../../.mamba/envs/score/lib/python3.7/site-packages/optimum/onnxruntime/trainer.py:1315: in _export
    self.model.to("cpu")
../../.mamba/envs/score/lib/python3.7/site-packages/optimum/onnxruntime/modeling_ort.py:122: in to
    provider = get_provider_for_device(self.device)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

device = 'cpu'

    def get_provider_for_device(device: torch.device) -> str:
        """
        Gets the ONNX Runtime provider associated with the PyTorch device (CPU/CUDA).
        """
>       return "CUDAExecutionProvider" if device.type.lower() == "cuda" else "CPUExecutionProvider"
E       AttributeError: 'str' object has no attribute 'type'

../../.mamba/envs/score/lib/python3.7/site-packages/optimum/onnxruntime/utils.py:165: AttributeError
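
The failure above can be reproduced in isolation. The following is a minimal sketch, not the library's code: `get_provider_for_device` copies the body shown in the traceback, while `StubDevice` (a stand-in for `torch.device` that models only its `.type` attribute, so the snippet runs without PyTorch) and `provider_for` (a defensive wrapper) are hypothetical names introduced for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StubDevice:
    """Hypothetical stand-in for torch.device; only .type is modelled."""
    type: str


def get_provider_for_device(device) -> str:
    # Same expression as in the traceback: it relies on device.type existing.
    return "CUDAExecutionProvider" if device.type.lower() == "cuda" else "CPUExecutionProvider"


def provider_for(device) -> str:
    """Hypothetical wrapper that accepts either a str or a device-like object."""
    if isinstance(device, str):
        # "cuda:0" -> "cuda". With PyTorch installed, torch.device(device)
        # would be the natural conversion here (assumption, not shown).
        device = StubDevice(device.split(":")[0])
    return get_provider_for_device(device)


# Passing a plain string reproduces the reported AttributeError.
try:
    get_provider_for_device("cpu")
    raised = False
except AttributeError:
    raised = True

print(raised)
print(provider_for("cpu"))
print(provider_for("cuda:0"))
```

Either converting at the call site (passing a device object instead of "cpu") or normalizing inside the callee avoids the crash; which one the project prefers is a maintainer decision.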

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [ ] Did you make sure to update the documentation with your changes?
  • [ ] Did you write any new necessary tests?

mausch, Aug 10 '22 08:08

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Hi @mausch,

The error message is a little ambiguous. Could you post the complete error message here, or a snippet that lets me reproduce the error?

For the time being, ORTTrainer only accepts a PyTorch model for instantiation. It looks like you may have passed an ORTModel to it. Is that so?

JingyaHuang, Aug 12 '22 10:08