
enable TritonFusedRMSNorm with local_map annotation

Open XilunWu opened this issue 1 year ago • 1 comments

Stack from ghstack (oldest at bottom):

  • -> #364

XilunWu avatar May 25 '24 00:05 XilunWu

Note: this test requires https://github.com/pytorch/pytorch/pull/126924 to land first.

XilunWu avatar May 25 '24 00:05 XilunWu

Your perf benchmark seems to be using batch size = 1. Can you rerun with batch_size=4 and update the perf table?

wanchaol avatar Jun 12 '24 22:06 wanchaol

@XilunWu The WPS for 8B in your summary still doesn't look right. I have the exact same settings, but the WPS on my side is something like this:

[rank0]:2024-06-13 12:53:16,156 - root - INFO - step:  1  loss: 12.2550  memory: 33.35GiB(35.09%)  wps: 542  mfu: 3.17%
[rank0]:2024-06-13 12:53:16,156 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-06-13 12:53:42,562 - root - INFO - step: 10  loss: 10.7798  memory: 41.18GiB(43.33%)  wps: 2,792  mfu: 16.35%
[rank0]:2024-06-13 12:54:04,065 - root - INFO - step: 20  loss:  9.1087  memory: 41.18GiB(43.33%)  wps: 3,812  mfu: 22.32%
[rank0]:2024-06-13 12:54:25,626 - root - INFO - step: 30  loss:  7.9951  memory: 41.18GiB(43.33%)  wps: 3,802  mfu: 22.27%

wanchaol avatar Jun 13 '24 19:06 wanchaol
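
When comparing WPS between two runs like this, it can help to pull the numbers out of the logs rather than eyeball them. Below is a minimal sketch that parses lines in the format shown above and averages WPS after warmup. The names `STEP_RE`, `parse_steps`, and `steady_state_wps`, and the `warmup_steps` default, are hypothetical helpers for illustration, not part of torchtitan.

```python
import re
from statistics import mean

# Assumed log line shape (taken from the log excerpt in this thread):
# "... step: 10  loss: 10.7798  memory: 41.18GiB(43.33%)  wps: 2,792  mfu: 16.35%"
STEP_RE = re.compile(
    r"step:\s*(\d+)\s+loss:\s*([\d.]+).*?wps:\s*([\d,]+)\s+mfu:\s*([\d.]+)%"
)


def parse_steps(log_text):
    """Return one dict per training-step line; other log lines are skipped."""
    rows = []
    for line in log_text.splitlines():
        m = STEP_RE.search(line)
        if m is None:
            continue
        rows.append(
            {
                "step": int(m.group(1)),
                "loss": float(m.group(2)),
                # WPS is printed with thousands separators, e.g. "2,792".
                "wps": int(m.group(3).replace(",", "")),
                "mfu": float(m.group(4)),
            }
        )
    return rows


def steady_state_wps(rows, warmup_steps=10):
    """Mean WPS over steps after warmup_steps, or None if there are none."""
    values = [r["wps"] for r in rows if r["step"] > warmup_steps]
    return mean(values) if values else None
```

Running `steady_state_wps(parse_steps(log_text))` on both runs' logs gives a single number per run to compare; the warmup cutoff is a judgment call, since the first logged steps include compilation and process-group setup.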