accelerate autocast on mps device
System Info
accelerate==0.18.0
system: Apple M2, macOS 13.3.1
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] One of the scripts in the examples/ folder of Accelerate or an officially supported `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)
- [X] My own task or dataset (give details below)
Reproduction
Below is the autocast implementation. The default for fp16 is CUDA; I wonder whether it is possible to adapt it to the macOS MPS device. Thanks very much.
````python
@contextmanager
def autocast(self):
    """
    Will apply automatic mixed-precision inside the block inside this context manager, if it is enabled. Nothing
    different will happen otherwise.

    Example:

    ```python
    >>> from accelerate import Accelerator

    >>> accelerator = Accelerator(mixed_precision="fp16")
    >>> with accelerator.autocast():
    ...     train()
    ```
    """
    if self.native_amp:
        if self.mixed_precision == "fp16" and is_torch_version(">=", "1.10"):
            autocast_context = torch.cuda.amp.autocast(dtype=torch.float16)
        elif self.mixed_precision == "bf16":
            if self.distributed_type in [DistributedType.NO, DistributedType.MULTI_CPU, DistributedType.MULTI_GPU]:
                autocast_context = torch.autocast(dtype=torch.bfloat16, device_type=self.device.type)
        else:
            autocast_context = torch.cuda.amp.autocast()

        autocast_context.__enter__()
        yield
        autocast_context.__exit__(*sys.exc_info())
    else:
        yield
````
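For reference, a device-agnostic version of this context manager could pick the dtype from the mixed-precision mode and pass the device type through to `torch.autocast`, instead of hard-coding `torch.cuda.amp.autocast`. A minimal sketch, assuming a PyTorch build whose `torch.autocast` accepts `device_type="mps"` (`autocast_for` is a hypothetical helper, not Accelerate's API):

```python
from contextlib import contextmanager

import torch


@contextmanager
def autocast_for(device: torch.device, mixed_precision: str = "fp16"):
    # Pick the target dtype from the requested mixed-precision mode.
    dtype = torch.float16 if mixed_precision == "fp16" else torch.bfloat16
    # torch.autocast takes the device type as a plain string ("cuda",
    # "cpu", and, on PyTorch builds that include MPS autocast, "mps"),
    # so nothing here is CUDA-specific.
    with torch.autocast(device_type=device.type, dtype=dtype):
        yield
```

With that, `with autocast_for(torch.device("mps")):` would mirror the fp16 branch above on Apple Silicon, provided the installed PyTorch supports MPS autocast.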
Expected behavior
Omission
When PyTorch adds this feature, we'll make sure it's possible in Accelerate: https://github.com/pytorch/pytorch/issues/88415
Thanks for the reply, really looking forward to having it.
Commenting to follow. It'll be great to finally have mixed precision training on M1!
I was trying to get this running. What I did was create a fork of the latest PyTorch repository and manually copy the changes from https://github.com/pytorch/pytorch/pull/99272 into my fork. What I ended up with was:
https://github.com/sagargulabani/pytorch/commit/060ccae622a72711af0f8cce30b8617907ecd526
I built this PyTorch locally from source to use in my conda environment.
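A quick smoke test along these lines can confirm that such a build exposes autocast on MPS at all (a hypothetical check, assuming the build accepts `device_type="mps"`):

```python
import torch

# The MPS backend itself must be available on this machine/build.
assert torch.backends.mps.is_available()

x = torch.randn(8, 8, device="mps")
# If MPS autocast works, the matmul inside the region runs in float16.
with torch.autocast(device_type="mps", dtype=torch.float16):
    y = x @ x
print(y.dtype)  # expected: torch.float16 when autocast is in effect
```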
However, now I am running into the following issue:
```
Traceback (most recent call last):
File "/Users/sagargulabani/.cache/huggingface/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py", line 1964, in <module>
main(args)
File "/Users/sagargulabani/.cache/huggingface/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py", line 1676, in main
model_pred = unet(
File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/anaconda3/envs/hf/lib/python3.10/site-packages/accelerate/utils/operations.py", line 822, in forward
return model_forward(*args, **kwargs)
File "/opt/anaconda3/envs/hf/lib/python3.10/site-packages/accelerate/utils/operations.py", line 810, in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/amp/autocast_mode.py", line 16, in decorate_autocast
return func(*args, **kwargs)
File "/Users/sagargulabani/.cache/huggingface/diffusers/src/diffusers/models/unets/unet_2d_condition.py", line 1216, in forward
sample, res_samples = downsample_block(
File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/Users/sagargulabani/.cache/huggingface/diffusers/src/diffusers/models/unets/unet_2d_blocks.py", line 1279, in forward
hidden_states = attn(
File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/Users/sagargulabani/.cache/huggingface/diffusers/src/diffusers/models/transformers/transformer_2d.py", line 397, in forward
hidden_states = block(
File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/Users/sagargulabani/.cache/huggingface/diffusers/src/diffusers/models/attention.py", line 366, in forward
attn_output = self.attn2(
File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/Users/sagargulabani/.cache/huggingface/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/Users/sagargulabani/.cache/huggingface/diffusers/src/diffusers/models/attention_processor.py", line 522, in forward
return self.processor(
File "/Users/sagargulabani/.cache/huggingface/diffusers/src/diffusers/models/attention_processor.py", line 1279, in __call__
hidden_states = F.scaled_dot_product_attention(
RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype: float key.dtype: c10::Half and value.dtype: c10::Half instead.
Steps: 0%| | 0/500 [00:01<?, ?it/s]
libc++abi: terminating due to uncaught exception of type std::__1::system_error: recursive_mutex lock failed: Invalid argument
Traceback (most recent call last):
File "/opt/anaconda3/envs/hf/bin/accelerate", line 8, in <module>
sys.exit(main())
File "/opt/anaconda3/envs/hf/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 46, in main
args.func(args)
File "/opt/anaconda3/envs/hf/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1057, in launch_command
simple_launcher(args)
File "/opt/anaconda3/envs/hf/lib/python3.10/site-packages/accelerate/commands/launch.py", line 673, in simple_launcher
raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
```
The main error seems to be this:
RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype: float key.dtype: c10::Half and value.dtype: c10::Half instead.
Wondering what could be the fix for this?
I have an M3 Max machine.
Thanks.
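From the traceback, the query reaches `F.scaled_dot_product_attention` in float32 while the key and value are already float16, which suggests part of the graph ran outside the autocast region (note Accelerate's `convert_to_fp32` wrapper a few frames up). One blunt way to confirm that while debugging is to coerce all three tensors to a single dtype just before the call; `sdpa_same_dtype` below is a hypothetical debugging shim, not a diffusers fix:

```python
import torch
import torch.nn.functional as F


def sdpa_same_dtype(query, key, value, **kwargs):
    # Cast key/value to the query's dtype so the kernel stops rejecting
    # mixed float32/float16 inputs. This masks the mismatch rather than
    # fixing its cause (a tensor produced outside the autocast region).
    key = key.to(query.dtype)
    value = value.to(query.dtype)
    return F.scaled_dot_product_attention(query, key, value, **kwargs)
```

If that makes the step run, the real fix is upstream: ensuring the tensor that arrives as float32 is produced inside the autocast region (or cast explicitly) rather than patching every attention call.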
Any progress? Thanks