torchtitan
enable TritonFusedRMSNorm with local_map annotation
Note: this test requires https://github.com/pytorch/pytorch/pull/126924 to be landed first.
Your perf benchmark seems to be using batch size = 1. Can you rerun it with batch_size=4 and update the perf table?
@XilunWu The WPS for 8B in your summary still doesn't look right. I have the exact same settings, but the WPS on my side looks something like this:
[rank0]:2024-06-13 12:53:16,156 - root - INFO - step: 1 loss: 12.2550 memory: 33.35GiB(35.09%) wps: 542 mfu: 3.17%
[rank0]:2024-06-13 12:53:16,156 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-06-13 12:53:42,562 - root - INFO - step: 10 loss: 10.7798 memory: 41.18GiB(43.33%) wps: 2,792 mfu: 16.35%
[rank0]:2024-06-13 12:54:04,065 - root - INFO - step: 20 loss: 9.1087 memory: 41.18GiB(43.33%) wps: 3,812 mfu: 22.32%
[rank0]:2024-06-13 12:54:25,626 - root - INFO - step: 30 loss: 7.9951 memory: 41.18GiB(43.33%) wps: 3,802 mfu: 22.27%
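When comparing WPS numbers like these against a summary table, it helps to drop the warm-up steps (step 1 reports only 542 WPS) and average the rest. The following is a minimal Python sketch that assumes the log format quoted above. The helper names `parse_metrics` and `steady_state_wps` and the `min_step` cutoff are hypothetical and are not part of torchtitan.

```python
import re
from statistics import mean

# Assumes the "step: N ... wps: X mfu: Y%" layout of the log lines quoted above.
STEP_RE = re.compile(r"step:\s*(\d+).*?wps:\s*([\d,]+)\s+mfu:\s*([\d.]+)%")

# Sample copied from the log above; the "Synchronizing..." line has no metrics.
SAMPLE_LOG = """\
[rank0]:2024-06-13 12:53:16,156 - root - INFO - step: 1 loss: 12.2550 memory: 33.35GiB(35.09%) wps: 542 mfu: 3.17%
[rank0]:2024-06-13 12:53:16,156 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-06-13 12:53:42,562 - root - INFO - step: 10 loss: 10.7798 memory: 41.18GiB(43.33%) wps: 2,792 mfu: 16.35%
[rank0]:2024-06-13 12:54:04,065 - root - INFO - step: 20 loss: 9.1087 memory: 41.18GiB(43.33%) wps: 3,812 mfu: 22.32%
[rank0]:2024-06-13 12:54:25,626 - root - INFO - step: 30 loss: 7.9951 memory: 41.18GiB(43.33%) wps: 3,802 mfu: 22.27%
"""


def parse_metrics(lines):
    """Return (step, wps, mfu_percent) tuples, skipping lines without metrics."""
    metrics = []
    for line in lines:
        m = STEP_RE.search(line)
        if m:
            step = int(m.group(1))
            wps = int(m.group(2).replace(",", ""))  # "3,812" -> 3812
            mfu = float(m.group(3))
            metrics.append((step, wps, mfu))
    return metrics


def steady_state_wps(metrics, min_step=20):
    """Average WPS over logged steps >= min_step (hypothetical warm-up cutoff)."""
    return mean(wps for step, wps, _ in metrics if step >= min_step)


if __name__ == "__main__":
    parsed = parse_metrics(SAMPLE_LOG.splitlines())
    print(parsed)
    print(steady_state_wps(parsed))
```

With `min_step=20`, only the step-20 and step-30 entries are averaged, which gives 3807 WPS for this sample. The cutoff is a judgment call: the step-10 entry still includes part of the warm-up, so it is excluded here.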