MPS Support
Copy of #790 but with proper rebase
@maximegmd is the real hero here.
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1233
- :page_facing_up: Preview Python docs built from this PR
Note: Links to docs will display an error until the docs builds have been completed.
:x: 1 New Failure
As of commit 47be0652ccd6c0e9d28116a58a91e8c7999aa05e with merge base e10142016798cf84f2e5c638a985014384f400a7:
NEW FAILURE - The following job has failed:
- Lint / lint (3.10) (gh): Process completed with exit code 1.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Update: hitting a blocker - I can't get this running on an M1 Mac with 32GB RAM on PyTorch 2.4.0. I hit the following error:
```
/AppleInternal/Library/BuildRoots/91a344b1-f985-11ee-b563-fe8bc7981bff/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:869: failed assertion `[MPSNDArray, initWithBuffer:descriptor:] Error: buffer is not large enough. Must be 393216 bytes
```
This happens right after the first loss is calculated.
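One thing worth trying before digging deeper (an assumption on my part - I haven't verified it against this crash) is PyTorch's CPU-fallback switch for MPS:

```shell
# Assumption: PYTORCH_ENABLE_MPS_FALLBACK tells PyTorch to fall back to CPU
# kernels for ops not implemented on MPS. It may not help here, since the
# MPSNDArray buffer assertion fires inside a kernel that IS implemented.
export PYTORCH_ENABLE_MPS_FALLBACK=1
# then rerun, e.g.:
# tune run full_finetune_single_device --config gemma/2B_full_single_device device=mps
```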
```
(joe-torchtune) jrcummings@jrcummings-mbp joe-torchtune % tune run full_finetune_single_device --config gemma/2B_full_single_device device=mps optimizer._component_=torch.optim.SGD
W0726 16:52:51.106000 8571538432 torch/distributed/elastic/multiprocessing/redirects.py:28] NOTE: Redirects are currently not supported in Windows or MacOs.
INFO:torchtune.utils.logging:Running FullFinetuneRecipeSingleDevice with resolved config:

batch_size: 1
checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/gemma-2b/
  checkpoint_files:
  - model-00001-of-00002.safetensors
  - model-00002-of-00002.safetensors
  model_type: GEMMA
  output_dir: /tmp/gemma-2b
  recipe_checkpoint: null
compile: false
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  train_on_input: true
device: mps
dtype: bf16
enable_activation_checkpointing: true
epochs: 1
gradient_accumulation_steps: 1
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
  _component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: /tmp/gemma-2b/logs
model:
  _component_: torchtune.models.gemma.gemma_2b
optimizer:
  _component_: torch.optim.SGD
  lr: 2.0e-05
optimizer_in_bwd: false
output_dir: /tmp/gemma-2b/logs
resume_from_checkpoint: false
seed: null
shuffle: true
tokenizer:
  _component_: torchtune.models.gemma.gemma_tokenizer
  path: /tmp/gemma-2b/tokenizer.model

DEBUG:torchtune.utils.logging:Setting manual seed to local seed 2219507826. Local seed is seed + rank = 2219507826 + 0
Writing logs to /tmp/gemma-2b/logs/log_1722027171.txt
INFO:torchtune.utils.logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils.logging:Tokenizer is initialized from file.
INFO:torchtune.utils.logging:Optimizer is initialized.
INFO:torchtune.utils.logging:Loss is initialized.
INFO:torchtune.utils.logging:Dataset and Sampler are initialized.
  0%|          | 0/52002 [00:00<?, ?it/s]/Users/jrcummings/miniconda3/envs/joe-torchtune/lib/python3.11/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
1|1|Loss: 2.1554598808288574:   0%|          | 1/52002 [00:03<52:31:28,  3.64s/it]/AppleInternal/Library/BuildRoots/91a344b1-f985-11ee-b563-fe8bc7981bff/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:869: failed assertion `[MPSNDArray, initWithBuffer:descriptor:] Error: buffer is not large enough. Must be 393216 bytes
'
zsh: abort      tune run full_finetune_single_device --config gemma/2B_full_single_device
(joe-torchtune) jrcummings@jrcummings-mbp joe-torchtune % /Users/jrcummings/miniconda3/envs/joe-torchtune/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
```
Looks like it could be related to https://github.com/pytorch/pytorch/issues/87351? Though that seems unlikely, as that issue is from October 2022.
With PyTorch nightlies, I manage to run 4 steps before the loss goes to NaN and the buffer issue comes back:

```
1|4|Loss: nan:   0%|          | 4/52002 [00:26<95:05:25,  6.58s/it]/AppleInternal/Library/BuildRoots/91a344b1-f985-11ee-b563-fe8bc7981bff/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:869: failed assertion `[MPSNDArray, initWithBuffer:descriptor:] Error: buffer is not large enough. Must be 163840 bytes
'
zsh: abort      tune run full_finetune_single_device --config gemma/2B_full_single_device
(joe-torchtune) jrcummings@jrcummings-mbp joe-torchtune % /Users/jrcummings/miniconda3/envs/joe-torchtune/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
```
That's odd - I was able to train for hours with no issue. I can try a run tomorrow (I have 64GB RAM - maybe that's why?).
I can test it on my machine, but I need details on which commands to run (general instructions). We need this in general for the docs - either the main README.md or an mps.md file in the cookbooks directory. Once the PR has been updated with this, I'll run through it on my M1 Max 64GB.
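For such a doc, the device-selection logic could be sketched roughly like this (a minimal illustration only - `get_default_device` is a hypothetical helper, and the availability flags are passed in explicitly so the logic is visible without importing torch; in practice they would come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`):

```python
def get_default_device(cuda_available: bool, mps_available: bool) -> str:
    """Pick the best available accelerator string, falling back to CPU.

    Hypothetical helper: mirrors the usual preference order
    cuda > mps > cpu for single-device recipes.
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"

# On an M-series Mac without CUDA, this resolves to "mps":
print(get_default_device(cuda_available=False, mps_available=True))
```

The resolved string is what the user would otherwise pass manually via the `device=mps` config override.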
- [ ] Compile isn't supported on MPS - need to error or warn when `compile=True` and default to `aot_eager`
- [ ] Ensure PyTorch nightly version / check for version (at least up-to-date with the fix for https://github.com/pytorch/pytorch/issues/130613#issuecomment-2226549088)
- [ ] Add memory tracking in `get_memory_stats`
- [x] `test_nf4_linear` crashes out due to importing bitsandbytes. Need to try/except? And mark the tests to skip if on MPS.
- [x] Maybe ensure latest macOS? What version are you bugging out on @joecummings? I'm on 14.4.1
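The first item above could look something like this (a sketch only - the function name and warning text are my own, not torchtune's API):

```python
import warnings
from typing import Optional

def resolve_compile_backend(device: str, compile_enabled: bool) -> Optional[str]:
    """Return a torch.compile backend name, or None if compilation is off.

    On MPS the default inductor backend isn't supported, so warn and fall
    back to the aot_eager backend instead of erroring out mid-run.
    """
    if not compile_enabled:
        return None
    if device == "mps":
        warnings.warn(
            "torch.compile's default backend is not supported on MPS; "
            "falling back to aot_eager."
        )
        return "aot_eager"
    return "inductor"
```

The returned string would then be passed as the `backend` argument to `torch.compile`.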
Also, several other tests fail due to (presumably) numerical precision differences between backends. For example, the PPO loss tests can't be reproduced between non-MPS and MPS.
See which tests fail here:

```
pytest tests -k "not distributed"
============================================================================== short test summary info ===============================================================================
FAILED tests/torchtune/models/llama3_1/test_position_embeddings.py::TestLlama3ScaledRoPE::test_forward - AssertionError: actual: -83.30372619628906, expected: -83.15229797363281
FAILED tests/torchtune/models/llama3_1/test_position_embeddings.py::TestLlama3ScaledRoPE::test_forward_with_curr_pos - AssertionError: actual: -83.30372619628906, expected: -83.15229797363281
FAILED tests/torchtune/models/llama3_1/test_position_embeddings.py::TestLlama3ScaledRoPE::test_forward_with_2d_pos_ids - AssertionError: actual: -83.30372619628906, expected: -83.15229797363281
FAILED tests/torchtune/modules/test_position_embeddings.py::TestRotaryPositionEmbedding::test_forward - AssertionError: actual: 2165.59619140625, expected: 2165.705322265625
FAILED tests/torchtune/modules/test_position_embeddings.py::TestRotaryPositionEmbedding::test_forward_with_curr_pos - AssertionError: actual: 2165.59619140625, expected: 2165.705322265625
FAILED tests/torchtune/modules/test_position_embeddings.py::TestRotaryPositionEmbedding::test_forward_with_packed_pos - AssertionError: actual: 2165.59619140625, expected: 2165.705322265625
FAILED tests/torchtune/modules/test_position_embeddings.py::TestPhi3RotaryPositionalEmbeddings::test_forward - AssertionError: actual: -381.06915283203125, expected: -381.06201171875
FAILED tests/torchtune/modules/test_transformer_decoder.py::TestTransformerDecoder::test_forward - AssertionError: actual: 20.48008918762207, expected: 20.479999542236328
FAILED tests/torchtune/utils/test_generation.py::TestTextGenerate::test_batched_generate - assert [[2, 3, 4, 5,...5, 6, 7, ...]] == [[2, 3, 4, 5,...5, 6, 7, ...]]
FAILED tests/torchtune/utils/test_generation.py::TestTextGenerate::test_stop_tokens - assert [[2, 3, 4, 5, 6, 7, ...]] == [[2, 3, 4, 5, 6, 7, ...]]
FAILED tests/torchtune/utils/test_generation.py::TestTextGenerate::test_stop_tokens_batched - assert [[2, 3, 4, 5,...5, 6, 7, ...]] == [[2, 3, 4, 5,...5, 6, 7, ...]]
FAILED tests/torchtune/utils/test_generation.py::TestTextGenerate::test_stop_tokens_batched_uneven_stopping - assert [[2, 3, 4, 5,...5, 6, 7, ...]] == [[2, 3, 4, 5,...5, 6, 7, ...]]
FAILED tests/torchtune/utils/test_generation.py::TestTextGenerate::test_stop_tokens_batched_uneven_stoppin_with_diff_pad_id - assert [[2, 3, 4, 5,...5, 6, 7, ...]] == [[2, 3, 4, 5,...5, 6, 7, ...]]
======================================================= 13 failed, 297 passed, 23 skipped, 27 deselected, 2 warnings in 24.76s =======================================================
```
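One way to handle the MPS-only failures is a guarded skip marker - a sketch, with the torch import wrapped so test collection also works on machines without torch. Shown here with `unittest.skipIf`; torchtune's suite uses pytest, where `pytest.mark.skipif` is the equivalent:

```python
import unittest

def mps_is_available() -> bool:
    # Guard the import so this helper is safe where torch isn't installed.
    try:
        import torch
        return torch.backends.mps.is_available()
    except ImportError:
        return False

class TestNF4Linear(unittest.TestCase):
    # Skip rather than crash on the bitsandbytes import when running on MPS.
    @unittest.skipIf(mps_is_available(), "NF4/bitsandbytes path not supported on MPS")
    def test_nf4_linear(self):
        # The real test body would exercise the NF4 linear layer here.
        self.assertTrue(True)
```

For the numerical-precision failures above, loosening tolerances under the same condition (rather than skipping outright) may be the better fix.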
@joecummings these are things I usually add locally. Happy to contribute any of these here to help, or consider that they may be out of scope atm. FWIW, I don't think we need another whole config file for this - just show the user how to specify device=mps in the README and/or docs.
Any update on this? MPS support has been underway for quite a while now without concluding.
Closed due to #1706. For anyone following along, feel free to try out torchtune on MPS and let us know how you get along!