
[llama4] use grouped_mm in moe for sm90

Open · IvanKobzarev opened this pull request 6 months ago · 2 comments

Stack from ghstack (oldest at bottom):

  • #2756
  • -> #2755

The enablement is copy-pasted from torchtitan. There is only one difference: the output of grouped_mm is uninitialized past offsets[-1] (it comes from at::empty):

        out = _grouped_mm(h, w2, offs=offsets)
        # rows past offsets[-1] come from at::empty and are uninitialized
        out[offsets[-1] :].zero_()

scatter_add has no way to ignore indices, and the padded entries of token_indices default to 0, so the uninitialized rows would get added into index 0. To avoid copying uninitialized data, zero_() is called after grouped_mm for now.
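
To make the failure mode concrete, here is a minimal self-contained sketch (made-up shapes, plain tensors standing in for the grouped_mm output) of why the tail must be zeroed:

    import torch

    num_tokens, dim = 8, 4
    offsets = torch.tensor([3, 6])      # expert boundaries; last expert ends at row 6
    n_valid = int(offsets[-1])

    out = torch.empty(num_tokens, dim)  # stands in for grouped_mm's at::empty output
    out[:n_valid] = 1.0                 # rows 0..5: real expert outputs
    out[n_valid:].zero_()               # rows 6..7: padding -- zeroed by the fix

    # padded entries of token_indices are left at their default of 0
    token_indices = torch.zeros(num_tokens, dtype=torch.int64)
    token_indices[:n_valid] = torch.arange(n_valid)

    result = torch.zeros(num_tokens, dim)
    result.scatter_add_(0, token_indices.unsqueeze(-1).expand(-1, dim), out)
    # without the zero_(), the two padding rows (uninitialized memory)
    # would have been added into result[0]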

grouped_mm in eager: tps_avg 3.2182 -> 4.44474

grouped_mm in compile: tps_avg 4.33463 -> 8.70872

Observing NaN loss with compile :( both with and without grouped_mm - will debug this separately.
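
Unrelated to the grouped_mm enablement itself, but for reference: torch.autograd anomaly mode is one generic way to localize where NaNs first appear in backward. A minimal self-contained sketch (not from this PR):

    import torch

    # Anomaly mode makes autograd raise as soon as a backward op produces NaN,
    # naming the responsible forward op (slow; enable only while debugging).
    torch.autograd.set_detect_anomaly(True)

    x = torch.zeros(1, requires_grad=True)
    loss = (torch.sqrt(x) * 0.0).sum()  # grad of sqrt at 0 is inf; 0 * inf = NaN
    loss.backward()                     # raises: 'SqrtBackward0' returned nan values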

no_compile + no grouped_mm (BASELINE):

Step 1 | loss:4.149928092956543 lr:2e-05 tokens_per_second_per_gpu:2.203409433364868 peak_memory_active:28.82542610168457 peak_memory_alloc:28.82542610168457 peak_memory_reserved:40.197265625
Step 2 | loss:2.3132553100585938 lr:2e-05 tokens_per_second_per_gpu:2.492661952972412 peak_memory_active:28.804787158966064 peak_memory_alloc:28.804787158966064 peak_memory_reserved:40.251953125
Step 3 | loss:1.6078141927719116 lr:2e-05 tokens_per_second_per_gpu:3.8224945068359375 peak_memory_active:28.853728771209717 peak_memory_alloc:28.853728771209717 peak_memory_reserved:40.251953125
Step 4 | loss:1.4519065618515015 lr:2e-05 tokens_per_second_per_gpu:3.001920700073242 peak_memory_active:28.84491729736328 peak_memory_alloc:28.84491729736328 peak_memory_reserved:40.251953125
Step 5 | loss:1.2131776809692383 lr:2e-05 tokens_per_second_per_gpu:3.3134357929229736 peak_memory_active:28.804935932159424 peak_memory_alloc:28.804935932159424 peak_memory_reserved:40.251953125
Step 6 | loss:1.411360263824463 lr:2e-05 tokens_per_second_per_gpu:4.279047966003418 peak_memory_active:28.83890438079834 peak_memory_alloc:28.83890438079834 peak_memory_reserved:40.251953125
Step 7 | loss:1.1743241548538208 lr:2e-05 tokens_per_second_per_gpu:3.129912853240967 peak_memory_active:28.785582065582275 peak_memory_alloc:28.785582065582275 peak_memory_reserved:40.251953125
Step 8 | loss:1.3950272798538208 lr:2e-05 tokens_per_second_per_gpu:3.4216036796569824 peak_memory_active:28.967872619628906 peak_memory_alloc:28.967872619628906 peak_memory_reserved:40.251953125
Step 9 | loss:1.2500101327896118 lr:2e-05 tokens_per_second_per_gpu:3.410867214202881 peak_memory_active:28.854501247406006 peak_memory_alloc:28.854501247406006 peak_memory_reserved:40.251953125
Step 10 | loss:1.1264036893844604 lr:2e-05 tokens_per_second_per_gpu:3.1065995693206787 peak_memory_active:28.782616138458252 peak_memory_alloc:28.782616138458252 peak_memory_reserved:40.251953125
 /tmp/torchtune/llama4_17Bx16E/full/logs
└─ $ tune_logs_tps log_1747910831.txt
tps_avg: 3.2182  n= 10
peak_memory_alloc max   28.97
peak_memory_reser max   40.25
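
tune_logs_tps is a local helper script; a hypothetical reimplementation consistent with the output above (mean of tokens_per_second_per_gpu over the parsed steps, max of the peak-memory columns) could look like:

    import re
    import sys

    # Hypothetical stand-in for the local `tune_logs_tps` helper used above:
    # parse the metric log and print mean tps plus peak-memory maxima.
    def summarize(path: str) -> None:
        pat = re.compile(
            r"tokens_per_second_per_gpu:(\S+).*"
            r"peak_memory_alloc:(\S+).*"
            r"peak_memory_reserved:(\S+)"
        )
        tps, alloc, reserved = [], [], []
        with open(path) as f:
            for line in f:
                if m := pat.search(line):
                    tps.append(float(m.group(1)))
                    alloc.append(float(m.group(2)))
                    reserved.append(float(m.group(3)))
        print(f"tps_avg: {sum(tps) / len(tps):.5g}  n= {len(tps)}")
        print(f"peak_memory_alloc max   {max(alloc):.2f}")
        print(f"peak_memory_reser max   {max(reserved):.2f}")

    if __name__ == "__main__":
        summarize(sys.argv[1])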





no_compile + grouped_mm (EAGER_GROUPED_MM):

└─ $ cat log_1747915295.txt
Step 1 | loss:4.163714408874512 lr:2e-05 tokens_per_second_per_gpu:2.171847343444824 peak_memory_active:27.444169998168945 peak_memory_alloc:27.444169998168945 peak_memory_reserved:38.9609375
Step 2 | loss:2.283238172531128 lr:2e-05 tokens_per_second_per_gpu:3.158111810684204 peak_memory_active:27.428916931152344 peak_memory_alloc:27.428916931152344 peak_memory_reserved:39.017578125
Step 3 | loss:1.5704680681228638 lr:2e-05 tokens_per_second_per_gpu:4.869690418243408 peak_memory_active:27.469550132751465 peak_memory_alloc:27.469550132751465 peak_memory_reserved:39.017578125
Step 4 | loss:1.4485288858413696 lr:2e-05 tokens_per_second_per_gpu:4.026949405670166 peak_memory_active:27.46179723739624 peak_memory_alloc:27.46179723739624 peak_memory_reserved:39.017578125
Step 5 | loss:1.1918020248413086 lr:2e-05 tokens_per_second_per_gpu:4.502264976501465 peak_memory_active:27.428401947021484 peak_memory_alloc:27.428401947021484 peak_memory_reserved:39.017578125
Step 6 | loss:1.3454418182373047 lr:2e-05 tokens_per_second_per_gpu:5.665891170501709 peak_memory_active:27.456215858459473 peak_memory_alloc:27.456215858459473 peak_memory_reserved:39.017578125
Step 7 | loss:1.185951590538025 lr:2e-05 tokens_per_second_per_gpu:4.356126308441162 peak_memory_active:27.412775993347168 peak_memory_alloc:27.412775993347168 peak_memory_reserved:39.017578125
Step 8 | loss:1.3995944261550903 lr:2e-05 tokens_per_second_per_gpu:4.84677267074585 peak_memory_active:27.562548637390137 peak_memory_alloc:27.562548637390137 peak_memory_reserved:39.017578125
Step 9 | loss:1.2516542673110962 lr:2e-05 tokens_per_second_per_gpu:4.785239219665527 peak_memory_active:27.469550132751465 peak_memory_alloc:27.469550132751465 peak_memory_reserved:39.017578125
Step 10 | loss:1.1245404481887817 lr:2e-05 tokens_per_second_per_gpu:4.080092906951904 peak_memory_active:27.41029644012451 peak_memory_alloc:27.41029644012451 peak_memory_reserved:39.017578125
Step 11 | loss:1.2505868673324585 lr:2e-05 tokens_per_second_per_gpu:4.96058988571167 peak_memory_active:27.416271209716797 peak_memory_alloc:27.416271209716797 peak_memory_reserved:39.0234375
Step 12 | loss:1.0976852178573608 lr:2e-05 tokens_per_second_per_gpu:4.555960655212402 peak_memory_active:27.415279388427734 peak_memory_alloc:27.415279388427734 peak_memory_reserved:39.0234375
Step 13 | loss:1.176978349685669 lr:2e-05 tokens_per_second_per_gpu:5.802102565765381 peak_memory_active:27.413253784179688 peak_memory_alloc:27.413253784179688 peak_memory_reserved:39.0234375
 /tmp/torchtune/llama4_17Bx16E/full/logs
└─ $ tune_logs_tps log_1747915295.txt
tps_avg: 4.44474  n= 13
peak_memory_alloc max   27.56
peak_memory_reser max   39.02








compile + no_grouped_mm (COMPILED BASELINE):

└─ $ cat log_1747915995.txt
Step 1 | loss:4.1595354080200195 lr:2e-05 tokens_per_second_per_gpu:0.42956066131591797 peak_memory_active:28.81977891921997 peak_memory_alloc:28.81977891921997 peak_memory_reserved:40.236328125
Step 2 | loss:2.2778103351593018 lr:2e-05 tokens_per_second_per_gpu:2.819465160369873 peak_memory_active:28.79981756210327 peak_memory_alloc:28.79981756210327 peak_memory_reserved:40.2890625
Step 3 | loss:1.786258578300476 lr:2e-05 tokens_per_second_per_gpu:5.622256755828857 peak_memory_active:28.846065998077393 peak_memory_alloc:28.846065998077393 peak_memory_reserved:40.2890625
Step 4 | loss:1.600595474243164 lr:2e-05 tokens_per_second_per_gpu:2.996225357055664 peak_memory_active:28.837817668914795 peak_memory_alloc:28.837817668914795 peak_memory_reserved:40.2890625
Step 5 | loss:1.627750039100647 lr:2e-05 tokens_per_second_per_gpu:6.902056694030762 peak_memory_active:28.79964590072632 peak_memory_alloc:28.79964590072632 peak_memory_reserved:40.2890625
Step 6 | loss:1.6751512289047241 lr:2e-05 tokens_per_second_per_gpu:3.8652405738830566 peak_memory_active:28.833858966827393 peak_memory_alloc:28.833858966827393 peak_memory_reserved:40.2890625
Step 7 | loss:1.5953153371810913 lr:2e-05 tokens_per_second_per_gpu:2.8824896812438965 peak_memory_active:28.781311511993408 peak_memory_alloc:28.781311511993408 peak_memory_reserved:40.2890625
Step 8 | loss:1.6390446424484253 lr:2e-05 tokens_per_second_per_gpu:4.305164813995361 peak_memory_active:28.948315620422363 peak_memory_alloc:28.948315620422363 peak_memory_reserved:40.2890625
Step 9 | loss:1.53915536403656 lr:2e-05 tokens_per_second_per_gpu:4.096757888793945 peak_memory_active:28.84673547744751 peak_memory_alloc:28.84673547744751 peak_memory_reserved:40.2890625
Step 10 | loss:1.4715595245361328 lr:2e-05 tokens_per_second_per_gpu:6.167354106903076 peak_memory_active:28.778648853302002 peak_memory_alloc:28.778648853302002 peak_memory_reserved:40.2890625
Step 11 | loss:nan lr:2e-05 tokens_per_second_per_gpu:4.351371765136719 peak_memory_active:28.787389278411865 peak_memory_alloc:28.787389278411865 peak_memory_reserved:40.2890625
 /tmp/torchtune/llama4_17Bx16E/full/logs
└─ $ tune_logs_tps log_1747915995.txt
tps_avg: 4.33463  n= 12
peak_memory_alloc max   28.95
peak_memory_reser max   40.29




compile + grouped_mm (COMPILED + GROUPED_MM):



└─ $ cat log_1747916676.txt
Step 1 | loss:4.163997650146484 lr:2e-05 tokens_per_second_per_gpu:0.9783019423484802 peak_memory_active:27.437111377716064 peak_memory_alloc:27.437111377716064 peak_memory_reserved:38.978515625
Step 2 | loss:2.2560360431671143 lr:2e-05 tokens_per_second_per_gpu:2.861786365509033 peak_memory_active:27.42238759994507 peak_memory_alloc:27.42238759994507 peak_memory_reserved:39.03515625
Step 3 | loss:1.8066692352294922 lr:2e-05 tokens_per_second_per_gpu:7.731247901916504 peak_memory_active:27.45892906188965 peak_memory_alloc:27.45892906188965 peak_memory_reserved:39.03515625
Step 4 | loss:1.7610063552856445 lr:2e-05 tokens_per_second_per_gpu:6.10059928894043 peak_memory_active:27.452006340026855 peak_memory_alloc:27.452006340026855 peak_memory_reserved:39.03515625
Step 5 | loss:1.4565016031265259 lr:2e-05 tokens_per_second_per_gpu:10.525087356567383 peak_memory_active:27.42187261581421 peak_memory_alloc:27.42187261581421 peak_memory_reserved:39.03515625
Step 6 | loss:nan lr:2e-05 tokens_per_second_per_gpu:8.103992462158203 peak_memory_active:27.448771476745605 peak_memory_alloc:27.448771476745605 peak_memory_reserved:39.03515625
Step 7 | loss:nan lr:2e-05 tokens_per_second_per_gpu:7.136734962463379 peak_memory_active:27.407429218292236 peak_memory_alloc:27.407429218292236 peak_memory_reserved:39.03515625
Step 8 | loss:nan lr:2e-05 tokens_per_second_per_gpu:11.029473304748535 peak_memory_active:27.543721675872803 peak_memory_alloc:27.543721675872803 peak_memory_reserved:39.03515625
Step 9 | loss:nan lr:2e-05 tokens_per_second_per_gpu:8.074372291564941 peak_memory_active:27.45892906188965 peak_memory_alloc:27.45892906188965 peak_memory_reserved:39.03515625
Step 10 | loss:nan lr:2e-05 tokens_per_second_per_gpu:8.008955955505371 peak_memory_active:27.40485429763794 peak_memory_alloc:27.40485429763794 peak_memory_reserved:39.03515625
Step 11 | loss:nan lr:2e-05 tokens_per_second_per_gpu:9.958040237426758 peak_memory_active:27.411057949066162 peak_memory_alloc:27.411057949066162 peak_memory_reserved:39.03515625
Step 12 | loss:nan lr:2e-05 tokens_per_second_per_gpu:9.930673599243164 peak_memory_active:27.410027980804443 peak_memory_alloc:27.410027980804443 peak_memory_reserved:39.03515625
Step 13 | loss:nan lr:2e-05 tokens_per_second_per_gpu:13.09176254272461 peak_memory_active:27.406914234161377 peak_memory_alloc:27.406914234161377 peak_memory_reserved:39.03515625
Step 14 | loss:nan lr:2e-05 tokens_per_second_per_gpu:1.858479619026184 peak_memory_active:27.490173816680908 peak_memory_alloc:27.490173816680908 peak_memory_reserved:39.03515625
Step 15 | loss:nan lr:2e-05 tokens_per_second_per_gpu:7.682432651519775 peak_memory_active:27.47649908065796 peak_memory_alloc:27.47649908065796 peak_memory_reserved:39.03515625
Step 16 | loss:nan lr:2e-05 tokens_per_second_per_gpu:11.597535133361816 peak_memory_active:27.44224739074707 peak_memory_alloc:27.44224739074707 peak_memory_reserved:39.03515625
Step 17 | loss:nan lr:2e-05 tokens_per_second_per_gpu:4.915543079376221 peak_memory_active:27.41466236114502 peak_memory_alloc:27.41466236114502 peak_memory_reserved:39.03515625
Step 18 | loss:nan lr:2e-05 tokens_per_second_per_gpu:11.867782592773438 peak_memory_active:27.400733947753906 peak_memory_alloc:27.400733947753906 peak_memory_reserved:39.03515625
Step 19 | loss:nan lr:2e-05 tokens_per_second_per_gpu:11.128471374511719 peak_memory_active:27.4446063041687 peak_memory_alloc:27.4446063041687 peak_memory_reserved:39.03515625
Step 20 | loss:nan lr:2e-05 tokens_per_second_per_gpu:14.666069030761719 peak_memory_active:27.431589126586914 peak_memory_alloc:27.431589126586914 peak_memory_reserved:39.03515625
Step 21 | loss:nan lr:2e-05 tokens_per_second_per_gpu:10.821459770202637 peak_memory_active:27.424962043762207 peak_memory_alloc:27.424962043762207 peak_memory_reserved:39.037109375
Step 22 | loss:nan lr:2e-05 tokens_per_second_per_gpu:8.869871139526367 peak_memory_active:27.419812202453613 peak_memory_alloc:27.419812202453613 peak_memory_reserved:39.037109375


└─ $ tune_logs_tps log_1747916676.txt
tps_avg: 8.70872  n= 24
peak_memory_alloc max   27.54
peak_memory_reser max   39.04


IvanKobzarev · May 21 '25 14:05

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2755

Note: Links to docs will display an error until the docs builds have been completed.

:x: 1 New Failure

As of commit 102ee6d4ad62954fc87d579cc3d37d3002ef2fe8 with merge base 23b3f7b421ff891c782d021021fed328c6509adc:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] · May 21 '25 14:05

Codecov Report

Attention: Patch coverage is 18.34862% with 89 lines in your changes missing coverage. Please review.

Please upload report for BASE (gh/IvanKobzarev/5/base@428f301). Learn more about missing BASE report.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| torchtune/modules/moe/indices.py | 0.00% | 57 Missing :warning: |
| torchtune/modules/moe/moe.py | 5.26% | 18 Missing :warning: |
| torchtune/modules/moe/experts.py | 44.44% | 10 Missing :warning: |
| recipes/full_finetune_distributed.py | 0.00% | 3 Missing :warning: |
| torchtune/modules/moe/config.py | 90.00% | 1 Missing :warning: |
Additional details and impacted files
@@                    Coverage Diff                    @@
##             gh/IvanKobzarev/5/base    #2755   +/-   ##
=========================================================
  Coverage                          ?   59.91%           
=========================================================
  Files                             ?      437           
  Lines                             ?    26848           
  Branches                          ?        0           
=========================================================
  Hits                              ?    16085           
  Misses                            ?    10763           
  Partials                          ?        0           

:umbrella: View full report in Codecov by Sentry.

codecov-commenter · May 22 '25 13:05