Device mismatch during evaluation when training on mps
I ran into an issue trying to train flan-t5 on an M1 using torchmetrics. Training metrics worked fine, but I got the following stacktrace when calculating evaluation metrics:
...
File "/Users/erosenthal/.pyenv/versions/pynlp/lib/python3.9/site-packages/composer/trainer/trainer.py", line 1804, in fit
self._train_loop()
File "/Users/erosenthal/.pyenv/versions/pynlp/lib/python3.9/site-packages/composer/trainer/trainer.py", line 2032, in _train_loop
self._run_evaluators(Event.BATCH_END)
File "/Users/erosenthal/.pyenv/versions/pynlp/lib/python3.9/site-packages/composer/trainer/trainer.py", line 2117, in _run_evaluators
self._eval_loop(
File "/Users/erosenthal/.pyenv/versions/pynlp/lib/python3.9/site-packages/composer/trainer/trainer.py", line 2833, in _eval_loop
self._original_model.update_metric(
File "/Users/erosenthal/.pyenv/versions/pynlp/lib/python3.9/site-packages/composer/models/huggingface.py", line 438, in update_metric
metric.update(outputs, self.labels) # pyright: ignore [reportGeneralTypeIssues]
File "/Users/erosenthal/.pyenv/versions/pynlp/lib/python3.9/site-packages/torchmetrics/metric.py", line 400, in wrapped_func
raise err
File "/Users/erosenthal/.pyenv/versions/pynlp/lib/python3.9/site-packages/torchmetrics/metric.py", line 390, in wrapped_func
update(*args, **kwargs)
File "/Users/erosenthal/.pyenv/versions/pynlp/lib/python3.9/site-packages/composer/metrics/nlp.py", line 111, in update
losses = self.loss_fn(logits, target)
File "/Users/erosenthal/.pyenv/versions/pynlp/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/Users/erosenthal/.pyenv/versions/pynlp/lib/python3.9/site-packages/torch/nn/modules/loss.py", line 1174, in forward
return F.cross_entropy(input, target, weight=self.weight,
File "/Users/erosenthal/.pyenv/versions/pynlp/lib/python3.9/site-packages/torch/nn/functional.py", line 3029, in cross_entropy
return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: Placeholder storage has not been allocated on MPS device!
I believe the issue is related to the following snippet of code:
https://github.com/mosaicml/composer/blob/ff59e862b92a7a1e62f72b57e36f528eb2c4bdfa/composer/trainer/trainer.py#L2846-L2858
The `outputs` tensor is explicitly moved to `cpu` if it's on `mps`, but the `batch` tensor is not. Hence, you inevitably have a device mismatch when updating metrics. AFAICT, `outputs` are not explicitly moved to `cpu` when they're on `mps` when updating training metrics, which is why I only saw this bug during evaluation.

If there really are numerical errors with `torchmetrics` on `mps`, then training metrics probably ought to be calculated on `cpu` in order to bring parity to the training and eval calculations. Additionally, the `batch` tensor will need to be moved to `cpu`.
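For illustration, here is a minimal sketch of the mismatch (my own hypothetical snippet, not Composer's internals): once the outputs are pulled back to `cpu` while the labels from the batch stay on `mps`, the metric's internal loss call fails.

```python
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"

# Outputs get moved off MPS (mirroring the workaround in the linked trainer code),
# but the labels from the batch stay on the original device.
outputs = torch.randn(4, 10, device=device).cpu()
labels = torch.randint(0, 10, (4,), device=device)

loss_fn = torch.nn.CrossEntropyLoss()
try:
    loss_fn(outputs, labels)  # device mismatch when running on mps
except RuntimeError as err:
    print(err)

# Moving the labels (or the whole batch) alongside the outputs avoids the mismatch:
loss = loss_fn(outputs, labels.to(outputs.device))
print(loss)
```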
I observed the same problem. My hack was to do `metric.update(outputs.to(targets.device), targets)` in `update_metric`, which is not ideal.
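Concretely, that hack sits in an override roughly like the following (a sketch, assuming the `update_metric(batch, outputs, metric)` hook shown in the stack trace above; `PatchedHuggingFaceModel` is just a placeholder name):

```python
from composer.models import HuggingFaceModel


class PatchedHuggingFaceModel(HuggingFaceModel):
    def update_metric(self, batch, outputs, metric) -> None:
        # Move the outputs back to the labels' device so torchmetrics sees matching devices.
        metric.update(outputs.to(self.labels.device), self.labels)
```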
Thanks for the bug report @antoinebrl and @erosenthal-square. We'll take a look and try a fix. `torchmetrics` was also recently updated to 1.0, and they may have fixed their numerical issues.
Thanks @hanlint!
@antoinebrl Yup, I did the same thing! Relatedly, I also had to convert `outputs` to full precision within `update_metric` if I tried to train with mixed precision.
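Put together, the two workarounds from this thread would look roughly like this sketch (again assuming the same `update_metric` hook; the `float32` cast only matters for mixed-precision runs):

```python
import torch

from composer.models import HuggingFaceModel


class PatchedHuggingFaceModel(HuggingFaceModel):
    def update_metric(self, batch, outputs, metric) -> None:
        # Match the labels' device and upcast half-precision logits before the metric update.
        outputs = outputs.to(device=self.labels.device, dtype=torch.float32)
        metric.update(outputs, self.labels)
```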
@hanlint, any update on this issue? The transfer back to CPU is a major slowdown.
https://github.com/mosaicml/composer/pull/3105
I believe this community PR addresses the issue -- sorry for the delay!
With respect to skipping it, I think this is an issue with MPS reliability and unfortunately outside our control :(
If it's avoidable, we can remove it. Unfortunately, I don't have a great test bed to debug these numerical issues on Macs.
@hanlint I saw that in https://github.com/mosaicml/composer/pull/1405 you reported an issue on M1 devices with `torchmetrics`. I am unfortunately unable to reproduce any error, and I could only find this issue, which is supposedly fixed: https://github.com/pytorch/pytorch/issues/78168

I would be interested to know if you have any snippets to share to reproduce the issue. While I made a fix for this issue, I feel the code moving metrics back to `cpu` in the case of an MPS device could be dropped. What @antoinebrl ended up doing was actually moving the outputs back to the MPS device (since the batch stayed there), and it worked, so I want to confirm we can drop this change.
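For reference, the kind of reproduction I had in mind is a rough `cpu`-vs-`mps` parity check along these lines (my own sketch, not taken from any linked issue):

```python
import torch
from torchmetrics.classification import MulticlassAccuracy

torch.manual_seed(0)
logits = torch.randn(1024, 10)
labels = torch.randint(0, 10, (1024,))

# Same metric, same data, computed on cpu and on mps; any gap points to an MPS numerical issue.
cpu_value = MulticlassAccuracy(num_classes=10)(logits, labels)

if torch.backends.mps.is_available():
    mps_metric = MulticlassAccuracy(num_classes=10).to("mps")
    mps_value = mps_metric(logits.to("mps"), labels.to("mps"))
    print(cpu_value.item(), mps_value.cpu().item())
```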
@hyenal I believe it was based on https://github.com/Lightning-AI/torchmetrics/issues/1727
I just reran the examples provided and encountered the same error :(
Given this, I assume most users prefer having correct results even if it is slower.