[llama4] use grouped_mm in moe for sm90
Stack from ghstack (oldest at bottom):
- #2756
- -> #2755
Enablement copy-pasted from torchtitan. There is only one difference: the output of grouped_mm is uninitialized after offsets[-1] (it comes from at::empty):
out = _grouped_mm(h, w2, offs=offsets)
out[offsets[-1] :].zero_()
scatter_add has no way to ignore indices, and the default 0 in token_indices would make the padding rows get copied to index 0. To prevent copying uninitialized data, zero_() is called after grouped_mm for now.
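For context, a minimal sketch of the expert computation this fix targets (shapes, variable names, and the scatter_add layout are illustrative assumptions, not torchtune's exact code; torch._grouped_mm is a private PyTorch op whose signature may vary by version):

```python
import torch

def experts_grouped_mm(h, w2, offsets, token_indices, out_tokens):
    # h:             [padded_tokens, dim_in]        rows already sorted/grouped by expert
    # w2:            [num_experts, dim_in, dim_out] per-expert weights
    # offsets:       [num_experts]                  cumulative end offset of each expert's group in h
    # token_indices: [padded_tokens]                destination token row for each grouped row;
    #                padding rows default to 0, and scatter_add cannot ignore them
    # out_tokens:    [num_tokens, dim_out]          dense per-token output buffer
    out = torch._grouped_mm(h, w2, offs=offsets)
    # Rows past offsets[-1] come from at::empty and are uninitialized; zero them
    # so the padding rows that get scatter-added into token 0 contribute nothing.
    out[offsets[-1]:].zero_()
    out_tokens.scatter_add_(
        0, token_indices.unsqueeze(-1).expand(-1, out.shape[-1]), out
    )
    return out_tokens
```

Zeroing the tail costs one extra small kernel; an alternative would be to rework token_indices so padding rows scatter somewhere harmless, but that would be a larger change.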
grouped_mm in eager: tps_avg 3.2182 -> 4.44474
grouped_mm in compile: tps_avg 4.33463 -> 8.70872
Observing NaN loss in compile :( both with and without grouped_mm; will be debugging this separately.
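Since the PR title restricts this to SM90, a possible gating check is sketched below (the helper name and fallback behavior are assumptions, not necessarily the gating actually used in this PR):

```python
import torch

def can_use_grouped_mm() -> bool:
    # Hypothetical helper: only take the grouped_mm path on Hopper (SM90) and
    # newer GPUs, where the grouped GEMM kernels are supported; otherwise fall
    # back to the per-expert loop. Not necessarily this PR's actual check.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9
```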
no_compile + no grouped_mm (BASELINE):
Step 1 | loss:4.149928092956543 lr:2e-05 tokens_per_second_per_gpu:2.203409433364868 peak_memory_active:28.82542610168457 peak_memory_alloc:28.82542610168457 peak_memory_reserved:40.197265625
Step 2 | loss:2.3132553100585938 lr:2e-05 tokens_per_second_per_gpu:2.492661952972412 peak_memory_active:28.804787158966064 peak_memory_alloc:28.804787158966064 peak_memory_reserved:40.251953125
Step 3 | loss:1.6078141927719116 lr:2e-05 tokens_per_second_per_gpu:3.8224945068359375 peak_memory_active:28.853728771209717 peak_memory_alloc:28.853728771209717 peak_memory_reserved:40.251953125
Step 4 | loss:1.4519065618515015 lr:2e-05 tokens_per_second_per_gpu:3.001920700073242 peak_memory_active:28.84491729736328 peak_memory_alloc:28.84491729736328 peak_memory_reserved:40.251953125
Step 5 | loss:1.2131776809692383 lr:2e-05 tokens_per_second_per_gpu:3.3134357929229736 peak_memory_active:28.804935932159424 peak_memory_alloc:28.804935932159424 peak_memory_reserved:40.251953125
Step 6 | loss:1.411360263824463 lr:2e-05 tokens_per_second_per_gpu:4.279047966003418 peak_memory_active:28.83890438079834 peak_memory_alloc:28.83890438079834 peak_memory_reserved:40.251953125
Step 7 | loss:1.1743241548538208 lr:2e-05 tokens_per_second_per_gpu:3.129912853240967 peak_memory_active:28.785582065582275 peak_memory_alloc:28.785582065582275 peak_memory_reserved:40.251953125
Step 8 | loss:1.3950272798538208 lr:2e-05 tokens_per_second_per_gpu:3.4216036796569824 peak_memory_active:28.967872619628906 peak_memory_alloc:28.967872619628906 peak_memory_reserved:40.251953125
Step 9 | loss:1.2500101327896118 lr:2e-05 tokens_per_second_per_gpu:3.410867214202881 peak_memory_active:28.854501247406006 peak_memory_alloc:28.854501247406006 peak_memory_reserved:40.251953125
Step 10 | loss:1.1264036893844604 lr:2e-05 tokens_per_second_per_gpu:3.1065995693206787 peak_memory_active:28.782616138458252 peak_memory_alloc:28.782616138458252 peak_memory_reserved:40.251953125
/tmp/torchtune/llama4_17Bx16E/full/logs
└─ $ tune_logs_tps log_1747910831.txt
tps_avg: 3.2182 n= 10
peak_memory_alloc max 28.97
peak_memory_reser max 40.25
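For reference, `tune_logs_tps` above is a local summarizer alias; a rough re-implementation of what it appears to compute (average tokens_per_second_per_gpu and max peak memory per log file) might look like the following, though the real alias's exact behavior is an assumption:

```python
import re
import sys

# Hypothetical "tune_logs_tps"-style summarizer for torchtune metric logs.
# The real alias used above is the author's local tool; step counting and
# formatting may differ from this sketch.
def summarize(path: str) -> None:
    pattern = re.compile(
        r"tokens_per_second_per_gpu:([\d.]+).*"
        r"peak_memory_alloc:([\d.]+).*peak_memory_reserved:([\d.]+)"
    )
    tps, alloc, reserved = [], [], []
    with open(path) as f:
        for line in f:
            m = pattern.search(line)
            if m:
                tps.append(float(m.group(1)))
                alloc.append(float(m.group(2)))
                reserved.append(float(m.group(3)))
    if not tps:
        print("no steps found")
        return
    print(f"tps_avg: {sum(tps) / len(tps):.5f} n= {len(tps)}")
    print(f"peak_memory_alloc max {max(alloc):.2f}")
    print(f"peak_memory_reser max {max(reserved):.2f}")

if __name__ == "__main__":
    summarize(sys.argv[1])
```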
no_compile + grouped_mm (EAGER_GROUPED_MM):
└─ $ cat log_1747915295.txt
Step 1 | loss:4.163714408874512 lr:2e-05 tokens_per_second_per_gpu:2.171847343444824 peak_memory_active:27.444169998168945 peak_memory_alloc:27.444169998168945 peak_memory_reserved:38.9609375
Step 2 | loss:2.283238172531128 lr:2e-05 tokens_per_second_per_gpu:3.158111810684204 peak_memory_active:27.428916931152344 peak_memory_alloc:27.428916931152344 peak_memory_reserved:39.017578125
Step 3 | loss:1.5704680681228638 lr:2e-05 tokens_per_second_per_gpu:4.869690418243408 peak_memory_active:27.469550132751465 peak_memory_alloc:27.469550132751465 peak_memory_reserved:39.017578125
Step 4 | loss:1.4485288858413696 lr:2e-05 tokens_per_second_per_gpu:4.026949405670166 peak_memory_active:27.46179723739624 peak_memory_alloc:27.46179723739624 peak_memory_reserved:39.017578125
Step 5 | loss:1.1918020248413086 lr:2e-05 tokens_per_second_per_gpu:4.502264976501465 peak_memory_active:27.428401947021484 peak_memory_alloc:27.428401947021484 peak_memory_reserved:39.017578125
Step 6 | loss:1.3454418182373047 lr:2e-05 tokens_per_second_per_gpu:5.665891170501709 peak_memory_active:27.456215858459473 peak_memory_alloc:27.456215858459473 peak_memory_reserved:39.017578125
Step 7 | loss:1.185951590538025 lr:2e-05 tokens_per_second_per_gpu:4.356126308441162 peak_memory_active:27.412775993347168 peak_memory_alloc:27.412775993347168 peak_memory_reserved:39.017578125
Step 8 | loss:1.3995944261550903 lr:2e-05 tokens_per_second_per_gpu:4.84677267074585 peak_memory_active:27.562548637390137 peak_memory_alloc:27.562548637390137 peak_memory_reserved:39.017578125
Step 9 | loss:1.2516542673110962 lr:2e-05 tokens_per_second_per_gpu:4.785239219665527 peak_memory_active:27.469550132751465 peak_memory_alloc:27.469550132751465 peak_memory_reserved:39.017578125
Step 10 | loss:1.1245404481887817 lr:2e-05 tokens_per_second_per_gpu:4.080092906951904 peak_memory_active:27.41029644012451 peak_memory_alloc:27.41029644012451 peak_memory_reserved:39.017578125
Step 11 | loss:1.2505868673324585 lr:2e-05 tokens_per_second_per_gpu:4.96058988571167 peak_memory_active:27.416271209716797 peak_memory_alloc:27.416271209716797 peak_memory_reserved:39.0234375
Step 12 | loss:1.0976852178573608 lr:2e-05 tokens_per_second_per_gpu:4.555960655212402 peak_memory_active:27.415279388427734 peak_memory_alloc:27.415279388427734 peak_memory_reserved:39.0234375
Step 13 | loss:1.176978349685669 lr:2e-05 tokens_per_second_per_gpu:5.802102565765381 peak_memory_active:27.413253784179688 peak_memory_alloc:27.413253784179688 peak_memory_reserved:39.0234375
/tmp/torchtune/llama4_17Bx16E/full/logs
└─ $ tune_logs_tps log_1747915295.txt
tps_avg: 4.44474 n= 13
peak_memory_alloc max 27.56
peak_memory_reser max 39.02
compile + no_grouped_mm (COMPILED BASELINE):
└─ $ cat log_1747915995.txt
Step 1 | loss:4.1595354080200195 lr:2e-05 tokens_per_second_per_gpu:0.42956066131591797 peak_memory_active:28.81977891921997 peak_memory_alloc:28.81977891921997 peak_memory_reserved:40.236328125
Step 2 | loss:2.2778103351593018 lr:2e-05 tokens_per_second_per_gpu:2.819465160369873 peak_memory_active:28.79981756210327 peak_memory_alloc:28.79981756210327 peak_memory_reserved:40.2890625
Step 3 | loss:1.786258578300476 lr:2e-05 tokens_per_second_per_gpu:5.622256755828857 peak_memory_active:28.846065998077393 peak_memory_alloc:28.846065998077393 peak_memory_reserved:40.2890625
Step 4 | loss:1.600595474243164 lr:2e-05 tokens_per_second_per_gpu:2.996225357055664 peak_memory_active:28.837817668914795 peak_memory_alloc:28.837817668914795 peak_memory_reserved:40.2890625
Step 5 | loss:1.627750039100647 lr:2e-05 tokens_per_second_per_gpu:6.902056694030762 peak_memory_active:28.79964590072632 peak_memory_alloc:28.79964590072632 peak_memory_reserved:40.2890625
Step 6 | loss:1.6751512289047241 lr:2e-05 tokens_per_second_per_gpu:3.8652405738830566 peak_memory_active:28.833858966827393 peak_memory_alloc:28.833858966827393 peak_memory_reserved:40.2890625
Step 7 | loss:1.5953153371810913 lr:2e-05 tokens_per_second_per_gpu:2.8824896812438965 peak_memory_active:28.781311511993408 peak_memory_alloc:28.781311511993408 peak_memory_reserved:40.2890625
Step 8 | loss:1.6390446424484253 lr:2e-05 tokens_per_second_per_gpu:4.305164813995361 peak_memory_active:28.948315620422363 peak_memory_alloc:28.948315620422363 peak_memory_reserved:40.2890625
Step 9 | loss:1.53915536403656 lr:2e-05 tokens_per_second_per_gpu:4.096757888793945 peak_memory_active:28.84673547744751 peak_memory_alloc:28.84673547744751 peak_memory_reserved:40.2890625
Step 10 | loss:1.4715595245361328 lr:2e-05 tokens_per_second_per_gpu:6.167354106903076 peak_memory_active:28.778648853302002 peak_memory_alloc:28.778648853302002 peak_memory_reserved:40.2890625
Step 11 | loss:nan lr:2e-05 tokens_per_second_per_gpu:4.351371765136719 peak_memory_active:28.787389278411865 peak_memory_alloc:28.787389278411865 peak_memory_reserved:40.2890625
/tmp/torchtune/llama4_17Bx16E/full/logs
└─ $ tune_logs_tps log_1747915995.txt
tps_avg: 4.33463 n= 12
peak_memory_alloc max 28.95
peak_memory_reser max 40.29
compile + grouped_mm (COMPILED + GROUPED_MM):
└─ $ cat log_1747916676.txt
Step 1 | loss:4.163997650146484 lr:2e-05 tokens_per_second_per_gpu:0.9783019423484802 peak_memory_active:27.437111377716064 peak_memory_alloc:27.437111377716064 peak_memory_reserved:38.978515625
Step 2 | loss:2.2560360431671143 lr:2e-05 tokens_per_second_per_gpu:2.861786365509033 peak_memory_active:27.42238759994507 peak_memory_alloc:27.42238759994507 peak_memory_reserved:39.03515625
Step 3 | loss:1.8066692352294922 lr:2e-05 tokens_per_second_per_gpu:7.731247901916504 peak_memory_active:27.45892906188965 peak_memory_alloc:27.45892906188965 peak_memory_reserved:39.03515625
Step 4 | loss:1.7610063552856445 lr:2e-05 tokens_per_second_per_gpu:6.10059928894043 peak_memory_active:27.452006340026855 peak_memory_alloc:27.452006340026855 peak_memory_reserved:39.03515625
Step 5 | loss:1.4565016031265259 lr:2e-05 tokens_per_second_per_gpu:10.525087356567383 peak_memory_active:27.42187261581421 peak_memory_alloc:27.42187261581421 peak_memory_reserved:39.03515625
Step 6 | loss:nan lr:2e-05 tokens_per_second_per_gpu:8.103992462158203 peak_memory_active:27.448771476745605 peak_memory_alloc:27.448771476745605 peak_memory_reserved:39.03515625
Step 7 | loss:nan lr:2e-05 tokens_per_second_per_gpu:7.136734962463379 peak_memory_active:27.407429218292236 peak_memory_alloc:27.407429218292236 peak_memory_reserved:39.03515625
Step 8 | loss:nan lr:2e-05 tokens_per_second_per_gpu:11.029473304748535 peak_memory_active:27.543721675872803 peak_memory_alloc:27.543721675872803 peak_memory_reserved:39.03515625
Step 9 | loss:nan lr:2e-05 tokens_per_second_per_gpu:8.074372291564941 peak_memory_active:27.45892906188965 peak_memory_alloc:27.45892906188965 peak_memory_reserved:39.03515625
Step 10 | loss:nan lr:2e-05 tokens_per_second_per_gpu:8.008955955505371 peak_memory_active:27.40485429763794 peak_memory_alloc:27.40485429763794 peak_memory_reserved:39.03515625
Step 11 | loss:nan lr:2e-05 tokens_per_second_per_gpu:9.958040237426758 peak_memory_active:27.411057949066162 peak_memory_alloc:27.411057949066162 peak_memory_reserved:39.03515625
Step 12 | loss:nan lr:2e-05 tokens_per_second_per_gpu:9.930673599243164 peak_memory_active:27.410027980804443 peak_memory_alloc:27.410027980804443 peak_memory_reserved:39.03515625
Step 13 | loss:nan lr:2e-05 tokens_per_second_per_gpu:13.09176254272461 peak_memory_active:27.406914234161377 peak_memory_alloc:27.406914234161377 peak_memory_reserved:39.03515625
Step 14 | loss:nan lr:2e-05 tokens_per_second_per_gpu:1.858479619026184 peak_memory_active:27.490173816680908 peak_memory_alloc:27.490173816680908 peak_memory_reserved:39.03515625
Step 15 | loss:nan lr:2e-05 tokens_per_second_per_gpu:7.682432651519775 peak_memory_active:27.47649908065796 peak_memory_alloc:27.47649908065796 peak_memory_reserved:39.03515625
Step 16 | loss:nan lr:2e-05 tokens_per_second_per_gpu:11.597535133361816 peak_memory_active:27.44224739074707 peak_memory_alloc:27.44224739074707 peak_memory_reserved:39.03515625
Step 17 | loss:nan lr:2e-05 tokens_per_second_per_gpu:4.915543079376221 peak_memory_active:27.41466236114502 peak_memory_alloc:27.41466236114502 peak_memory_reserved:39.03515625
Step 18 | loss:nan lr:2e-05 tokens_per_second_per_gpu:11.867782592773438 peak_memory_active:27.400733947753906 peak_memory_alloc:27.400733947753906 peak_memory_reserved:39.03515625
Step 19 | loss:nan lr:2e-05 tokens_per_second_per_gpu:11.128471374511719 peak_memory_active:27.4446063041687 peak_memory_alloc:27.4446063041687 peak_memory_reserved:39.03515625
Step 20 | loss:nan lr:2e-05 tokens_per_second_per_gpu:14.666069030761719 peak_memory_active:27.431589126586914 peak_memory_alloc:27.431589126586914 peak_memory_reserved:39.03515625
Step 21 | loss:nan lr:2e-05 tokens_per_second_per_gpu:10.821459770202637 peak_memory_active:27.424962043762207 peak_memory_alloc:27.424962043762207 peak_memory_reserved:39.037109375
Step 22 | loss:nan lr:2e-05 tokens_per_second_per_gpu:8.869871139526367 peak_memory_active:27.419812202453613 peak_memory_alloc:27.419812202453613 peak_memory_reserved:39.037109375
└─ $ tune_logs_tps log_1747916676.txt
tps_avg: 8.70872 n= 24
peak_memory_alloc max 27.54
peak_memory_reser max 39.04
:x: 1 New Failure
As of commit 102ee6d4ad62954fc87d579cc3d37d3002ef2fe8 with merge base 23b3f7b421ff891c782d021021fed328c6509adc:
NEW FAILURE - The following job has failed:
- GPU tests / gpu_test (3.9, stable) (gh)
tests/recipes/test_lora_dpo_distributed.py::TestLoRADPODistributedRecipe::test_training_state_on_resume[True]
Codecov Report
Attention: Patch coverage is 18.34862% with 89 lines in your changes missing coverage. Please review.
Please upload report for BASE (gh/IvanKobzarev/5/base@428f301). Learn more about missing BASE report.
Additional details and impacted files
| Coverage Diff | gh/IvanKobzarev/5/base (BASE) | #2755 |
|---|---|---|
| Coverage | ? | 59.91% |
| Files | ? | 437 |
| Lines | ? | 26848 |
| Branches | ? | 0 |
| Hits | ? | 16085 |
| Misses | ? | 10763 |
| Partials | ? | 0 |