torchtune

MPS Support

Open joecummings opened this issue 1 year ago • 6 comments

Copy of #790, but with a proper rebase.

@maximegmd is the real hero here.

joecummings, Jul 26 '24 19:07

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1233

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 47be0652ccd6c0e9d28116a58a91e8c7999aa05e with merge base e10142016798cf84f2e5c638a985014384f400a7:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot], Jul 26 '24 19:07

Update: I'm hitting a blocker. I can't get this running on an M1 Mac with 32 GB of RAM and PyTorch 2.4.0.

I hit the following error:

/AppleInternal/Library/BuildRoots/91a344b1-f985-11ee-b563-fe8bc7981bff/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:869: failed assertion `[MPSNDArray, initWithBuffer:descriptor:] Error: buffer is not large enough. Must be 393216 bytes

This happens right after the first loss is calculated.

(joe-torchtune) jrcummings@jrcummings-mbp joe-torchtune % tune run full_finetune_single_device --config gemma/2B_full_single_device device=mps optimizer._component_=torch.optim.SGD
W0726 16:52:51.106000 8571538432 torch/distributed/elastic/multiprocessing/redirects.py:28] NOTE: Redirects are currently not supported in Windows or MacOs.
INFO:torchtune.utils.logging:Running FullFinetuneRecipeSingleDevice with resolved config:

batch_size: 1
checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/gemma-2b/
  checkpoint_files:
  - model-00001-of-00002.safetensors
  - model-00002-of-00002.safetensors
  model_type: GEMMA
  output_dir: /tmp/gemma-2b
  recipe_checkpoint: null
compile: false
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  train_on_input: true
device: mps
dtype: bf16
enable_activation_checkpointing: true
epochs: 1
gradient_accumulation_steps: 1
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
  _component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: /tmp/gemma-2b/logs
model:
  _component_: torchtune.models.gemma.gemma_2b
optimizer:
  _component_: torch.optim.SGD
  lr: 2.0e-05
optimizer_in_bwd: false
output_dir: /tmp/gemma-2b/logs
resume_from_checkpoint: false
seed: null
shuffle: true
tokenizer:
  _component_: torchtune.models.gemma.gemma_tokenizer
  path: /tmp/gemma-2b/tokenizer.model

DEBUG:torchtune.utils.logging:Setting manual seed to local seed 2219507826. Local seed is seed + rank = 2219507826 + 0
Writing logs to /tmp/gemma-2b/logs/log_1722027171.txt
INFO:torchtune.utils.logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils.logging:Tokenizer is initialized from file.
INFO:torchtune.utils.logging:Optimizer is initialized.
INFO:torchtune.utils.logging:Loss is initialized.
INFO:torchtune.utils.logging:Dataset and Sampler are initialized.
  0%|          | 0/52002 [00:00<?, ?it/s]/Users/jrcummings/miniconda3/envs/joe-torchtune/lib/python3.11/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
1|1|Loss: 2.1554598808288574:   0%|          | 1/52002 [00:03<52:31:28,  3.64s/it]/AppleInternal/Library/BuildRoots/91a344b1-f985-11ee-b563-fe8bc7981bff/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:869: failed assertion `[MPSNDArray, initWithBuffer:descriptor:] Error: buffer is not large enough. Must be 393216 bytes
'
zsh: abort      tune run full_finetune_single_device --config gemma/2B_full_single_device
(joe-torchtune) jrcummings@jrcummings-mbp joe-torchtune % /Users/jrcummings/miniconda3/envs/joe-torchtune/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

This looks like it could be related to https://github.com/pytorch/pytorch/issues/87351, but that seems unlikely, since that issue dates from October 2022.

joecummings, Jul 26 '24 20:07

With the PyTorch nightlies, I manage to run 4 steps before the loss becomes NaN and the buffer error comes back:

1|4|Loss: nan:   0%|          | 4/52002 [00:26<95:05:25,  6.58s/it]/AppleInternal/Library/BuildRoots/91a344b1-f985-11ee-b563-fe8bc7981bff/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:869: failed assertion `[MPSNDArray, initWithBuffer:descriptor:] Error: buffer is not large enough. Must be 163840 bytes
'
zsh: abort      tune run full_finetune_single_device --config gemma/2B_full_single_device
(joe-torchtune) jrcummings@jrcummings-mbp joe-torchtune % /Users/jrcummings/miniconda3/envs/joe-torchtune/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
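Because PyTorch 2.4.0 and the nightlies fail at different points here, a recipe could refuse to run on MPS with a build older than the one containing the upstream fix. Below is a minimal, stdlib-only sketch for comparing version strings such as `2.4.0` or `2.5.0.dev20240801+cpu`. `parse_torch_version` and `meets_minimum` are hypothetical helpers, and the exact minimum nightly is not stated in this thread, so it is left as a parameter.

```python
import re
from typing import Tuple

# Matches "2.4.0", "2.5.0.dev20240801", "2.5.0.dev20240801+cpu", ...
_VERSION_RE = re.compile(r"^(\d+)\.(\d+)\.(\d+)(?:\.dev(\d+))?")


def parse_torch_version(version: str) -> Tuple[int, int, int, int, int]:
    """Turn a torch version string into a tuple that sorts correctly."""
    match = _VERSION_RE.match(version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    major, minor, patch, dev_date = match.groups()
    # Nightlies (".devYYYYMMDD") sort before the release they lead up to.
    # Other pre-release tags such as "a0" or "rc1" are ignored in this sketch
    # (they are treated as the release itself).
    is_release = 1 if dev_date is None else 0
    return (int(major), int(minor), int(patch), is_release, int(dev_date or 0))


def meets_minimum(version: str, minimum: str) -> bool:
    """True if `version` is at least as new as `minimum`."""
    return parse_torch_version(version) >= parse_torch_version(minimum)
```

In a recipe this would be called as `meets_minimum(torch.__version__, minimum)`. Comparing plain tuples keeps the ordering rules explicit; `packaging.version.Version` would be the more complete choice if that dependency is acceptable.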

joecummings, Jul 26 '24 21:07

That's odd; I was able to train for hours with no issues. I can try a run tomorrow. I have 64 GB of RAM, so maybe that's the difference?
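Whether 32 GB versus 64 GB of unified memory explains the difference is easier to judge with MPS memory numbers in the logs, which is what the `get_memory_stats` item raised in this thread asks for. A minimal sketch, assuming `torch.mps.current_allocated_memory()` and `torch.mps.driver_allocated_memory()` from recent PyTorch releases. The function name and dict keys are hypothetical and are not torchtune's actual `get_memory_stats` output; the `mps` parameter exists only so the function can be exercised without a Mac.

```python
from typing import Any, Dict, Optional

_GIB = 1024**3


def get_mps_memory_stats(mps: Optional[Any] = None) -> Dict[str, float]:
    """Return MPS memory usage in GiB (hypothetical helper, not torchtune's API)."""
    if mps is None:
        import torch  # needs a PyTorch build with MPS support

        mps = torch.mps
    return {
        # Bytes currently occupied by tensors on the MPS device.
        "mps_tensor_alloc_gib": mps.current_allocated_memory() / _GIB,
        # Total bytes the Metal driver has allocated for this process.
        "mps_driver_alloc_gib": mps.driver_allocated_memory() / _GIB,
    }
```

These numbers could be logged next to the loss when `log_peak_memory_stats` is true, which would make runs on the two machines directly comparable.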

maximegmd, Jul 26 '24 21:07

I can test it on my machine, but I need details on which commands to run (general instructions). We need these in the docs anyway, either in the main README.md or in an mps.md file in the cookbooks directory. Once the PR has been updated with them, I'll run through it on my M1 Max (64 GB).
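A natural first step in such instructions is a pre-flight check that the machine can actually use MPS before launching `tune run ... device=mps`. A minimal sketch using `torch.backends.mps.is_built()` and `torch.backends.mps.is_available()`; the function name and report keys are hypothetical, and the torch checks are skipped if PyTorch is not installed.

```python
import importlib.util
import platform
from typing import Dict


def mps_preflight() -> Dict[str, bool]:
    """Hypothetical pre-flight report for running torchtune with device=mps."""
    report = {
        "is_macos": platform.system() == "Darwin",
        "is_apple_silicon": platform.machine() == "arm64",
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "mps_built": False,
        "mps_available": False,
    }
    if report["torch_installed"]:
        import torch

        # is_built(): this PyTorch binary was compiled with MPS support.
        # is_available(): MPS is usable on this machine right now.
        report["mps_built"] = bool(torch.backends.mps.is_built())
        report["mps_available"] = bool(torch.backends.mps.is_available())
    return report


for key, value in mps_preflight().items():
    print(f"{key}: {value}")
```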

byjlw, Jul 29 '24 16:07

  • [ ] Compile isn't supported on MPS: raise an error or warn when compile=True, and default to aot_eager.
  • [ ] Require a PyTorch nightly, or check the installed version (it must at least include the fix for https://github.com/pytorch/pytorch/issues/130613#issuecomment-2226549088).
  • [ ] Add memory tracking in get_memory_stats.
  • [x] test_nf4_linear crashes because it imports bitsandbytes. Wrap the import in a try/except? Also mark those tests to skip on MPS.
  • [x] Maybe require the latest macOS? Which version are you hitting the bug on, @joecummings? I'm on 14.4.1.
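The first checklist item could look roughly like the sketch below: a hypothetical helper that maps the recipe's `compile` and `device` settings to a `torch.compile` backend, warning and falling back to `aot_eager` on MPS as proposed. `inductor` is `torch.compile`'s default backend; the helper name and warning text are assumptions, not torchtune code.

```python
import warnings
from typing import Optional


def resolve_compile_backend(device: str, use_compile: bool) -> Optional[str]:
    """Return a torch.compile backend name, or None to leave the model uncompiled.

    Hypothetical helper: on MPS it warns and falls back to "aot_eager"
    instead of using the default backend.
    """
    if not use_compile:
        return None
    if device == "mps":
        warnings.warn(
            "compile=True is not supported with device=mps; "
            "falling back to the 'aot_eager' backend.",
            UserWarning,
            stacklevel=2,
        )
        return "aot_eager"
    return "inductor"  # torch.compile's default backend
```

A recipe would then call `torch.compile(model, backend=backend)` only when the helper returns a backend name. Raising an error instead of warning would be the stricter alternative mentioned in the checklist.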

Also, several other tests fail, presumably because of numerical representation differences. For example, the PPO loss test results can't be reproduced between MPS and non-MPS devices.

See which tests here
pytest tests -k "not distributed"
============================================================================== short test summary info ===============================================================================
FAILED tests/torchtune/models/llama3_1/test_position_embeddings.py::TestLlama3ScaledRoPE::test_forward - AssertionError: actual: -83.30372619628906, expected: -83.15229797363281
FAILED tests/torchtune/models/llama3_1/test_position_embeddings.py::TestLlama3ScaledRoPE::test_forward_with_curr_pos - AssertionError: actual: -83.30372619628906, expected: -83.15229797363281
FAILED tests/torchtune/models/llama3_1/test_position_embeddings.py::TestLlama3ScaledRoPE::test_forward_with_2d_pos_ids - AssertionError: actual: -83.30372619628906, expected: -83.15229797363281
FAILED tests/torchtune/modules/test_position_embeddings.py::TestRotaryPositionEmbedding::test_forward - AssertionError: actual: 2165.59619140625, expected: 2165.705322265625
FAILED tests/torchtune/modules/test_position_embeddings.py::TestRotaryPositionEmbedding::test_forward_with_curr_pos - AssertionError: actual: 2165.59619140625, expected: 2165.705322265625
FAILED tests/torchtune/modules/test_position_embeddings.py::TestRotaryPositionEmbedding::test_forward_with_packed_pos - AssertionError: actual: 2165.59619140625, expected: 2165.705322265625
FAILED tests/torchtune/modules/test_position_embeddings.py::TestPhi3RotaryPositionalEmbeddings::test_forward - AssertionError: actual: -381.06915283203125, expected: -381.06201171875
FAILED tests/torchtune/modules/test_transformer_decoder.py::TestTransformerDecoder::test_forward - AssertionError: actual: 20.48008918762207, expected: 20.479999542236328
FAILED tests/torchtune/utils/test_generation.py::TestTextGenerate::test_batched_generate - assert [[2, 3, 4, 5,...5, 6, 7, ...]] == [[2, 3, 4, 5,...5, 6, 7, ...]]
FAILED tests/torchtune/utils/test_generation.py::TestTextGenerate::test_stop_tokens - assert [[2, 3, 4, 5, 6, 7, ...]] == [[2, 3, 4, 5, 6, 7, ...]]
FAILED tests/torchtune/utils/test_generation.py::TestTextGenerate::test_stop_tokens_batched - assert [[2, 3, 4, 5,...5, 6, 7, ...]] == [[2, 3, 4, 5,...5, 6, 7, ...]]
FAILED tests/torchtune/utils/test_generation.py::TestTextGenerate::test_stop_tokens_batched_uneven_stopping - assert [[2, 3, 4, 5,...5, 6, 7, ...]] == [[2, 3, 4, 5,...5, 6, 7, ...]]
FAILED tests/torchtune/utils/test_generation.py::TestTextGenerate::test_stop_tokens_batched_uneven_stoppin_with_diff_pad_id - assert [[2, 3, 4, 5,...5, 6, 7, ...]] == [[2, 3, 4, 5,...5, 6, 7, ...]]
======================================================= 13 failed, 297 passed, 23 skipped, 27 deselected, 2 warnings in 24.76s =======================================================
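To gauge how large these discrepancies are, and whether a looser MPS-specific tolerance would be reasonable, the relative error of each numeric mismatch in the summary can be computed directly. The values below are copied from the failures listed above; the `rel_tol=1e-5` threshold is an arbitrary example, not torchtune's actual test tolerance.

```python
import math
from typing import List, Tuple

# (test, actual on MPS, expected) copied from the pytest summary.
MPS_MISMATCHES: List[Tuple[str, float, float]] = [
    ("TestLlama3ScaledRoPE::test_forward", -83.30372619628906, -83.15229797363281),
    ("TestRotaryPositionEmbedding::test_forward", 2165.59619140625, 2165.705322265625),
    ("TestPhi3RotaryPositionalEmbeddings::test_forward", -381.06915283203125, -381.06201171875),
    ("TestTransformerDecoder::test_forward", 20.48008918762207, 20.479999542236328),
]


def relative_error(actual: float, expected: float) -> float:
    """|actual - expected| / |expected|."""
    return abs(actual - expected) / abs(expected)


for name, actual, expected in MPS_MISMATCHES:
    rel = relative_error(actual, expected)
    within = math.isclose(actual, expected, rel_tol=1e-5)
    print(f"{name}: rel_err={rel:.2e}, within rel_tol=1e-5: {within}")
```

With these numbers, only the `TestTransformerDecoder` mismatch falls within `rel_tol=1e-5`; the Llama 3.1 scaled RoPE outputs differ by roughly 0.18%.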

@joecummings, these are things I usually add locally. I'm happy to contribute any of them here, or we can treat them as out of scope for now. FWIW, I don't think we need a whole separate config file for this; we can just show users how to specify device=mps in the README and/or docs.

salmanmohammadi, Aug 05 '24 11:08

Any update on this? MPS support has been in progress for quite a while without being finished.

bhupesh-sf, Sep 06 '24 16:09

Closed due to #1706. For anyone following along, feel free to try out torchtune on MPS and let us know how you get along!
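For anyone trying this out, MPS is selected with a plain `device=mps` config override, as in the `tune run` commands earlier in this thread. A small sketch that assembles such a command; `tune_run_argv` is a hypothetical helper that only builds and prints the command rather than running it.

```python
import shlex
from typing import Dict, List


def tune_run_argv(recipe: str, config: str, overrides: Dict[str, str]) -> List[str]:
    """Build argv for `tune run` with key=value config overrides (hypothetical helper)."""
    argv = ["tune", "run", recipe, "--config", config]
    argv += [f"{key}={value}" for key, value in overrides.items()]
    return argv


argv = tune_run_argv(
    "full_finetune_single_device",
    "gemma/2B_full_single_device",
    {"device": "mps"},
)
print(shlex.join(argv))
# prints: tune run full_finetune_single_device --config gemma/2B_full_single_device device=mps
```

Further overrides, such as the `optimizer._component_=torch.optim.SGD` used in the first failing run, would be added to the same dict.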

salmanmohammadi, Sep 28 '24 09:09