optimum
Fix CPU device parameter
What does this PR do?
Fixes a crash in `Trainer` when running on CPU. The device passed to `model.to` is expected to be of type `torch.device`, not `str`; otherwise the call crashes with:
```
../../.mamba/envs/score/lib/python3.7/site-packages/optimum/onnxruntime/trainer.py:1315: in _export
    self.model.to("cpu")
../../.mamba/envs/score/lib/python3.7/site-packages/optimum/onnxruntime/modeling_ort.py:122: in to
    provider = get_provider_for_device(self.device)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

device = 'cpu'

    def get_provider_for_device(device: torch.device) -> str:
        """
        Gets the ONNX Runtime provider associated with the PyTorch device (CPU/CUDA).
        """
>       return "CUDAExecutionProvider" if device.type.lower() == "cuda" else "CPUExecutionProvider"
E       AttributeError: 'str' object has no attribute 'type'

../../.mamba/envs/score/lib/python3.7/site-packages/optimum/onnxruntime/utils.py:165: AttributeError
```
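For illustration, one way to remove the mismatch is to make the helper tolerant of plain strings by normalizing before reading `.type` (a sketch only, not necessarily the patch in this PR — the call site could equally pass `torch.device("cpu")` instead; `make_device` below is a hypothetical stand-in for `torch.device` so the snippet runs without PyTorch):

```python
from types import SimpleNamespace

def make_device(name: str):
    # Hypothetical stand-in for torch.device; the real object exposes a
    # `.type` attribute such as "cpu" or "cuda".
    return SimpleNamespace(type=name)

def get_provider_for_device(device) -> str:
    """Defensive variant: accept either a plain string or a device object
    by normalizing the string case before reading `.type`."""
    if isinstance(device, str):
        device = make_device(device)
    return "CUDAExecutionProvider" if device.type.lower() == "cuda" else "CPUExecutionProvider"

print(get_provider_for_device("cpu"))                # no AttributeError on a plain string
print(get_provider_for_device(make_device("cuda")))
```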
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you make sure to update the documentation with your changes?
- [ ] Did you write any new necessary tests?
Hi @mausch,
The error message is a little bit ambiguous. Can you post the complete error message here, or a snippet for me to reproduce the error?
For the time being, `ORTTrainer` only takes a PyTorch model for instantiation. It seems to me that you have probably passed an `ORTModel` to it — is that so?