torchtitan

selective compilation - norm layers only

Open lessw2020 opened this issue 9 months ago • 2 comments

This PR adds an option to selectively compile only the norm layers, and is mainly targeted at RMSNorm. When using RMSNorm, compiling just the norm layers gives speedups nearly comparable to those from the fused RMSNorm Triton kernel. Credit to @wconstab for this idea.
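A minimal sketch of the idea: walk the model, find every RMSNorm submodule, and replace it with a `torch.compile`-wrapped version while leaving the rest of the model eager. The `RMSNorm` class, the `ToyBlock` model, and the `compile_norm_layers` helper are hypothetical stand-ins written for illustration. They are not torchtitan's actual modules or the code in this PR.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Eager RMSNorm: x / sqrt(mean(x^2) + eps) * weight (illustrative version)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Do the reduction in float32, then cast back to the input dtype.
        x_fp32 = x.float()
        normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed.type_as(x) * self.weight


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a transformer block: norm, then a linear layer."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm = RMSNorm(dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.proj(self.norm(x))


def compile_norm_layers(model: nn.Module, norm_cls: type = RMSNorm) -> int:
    """Wrap every `norm_cls` submodule with torch.compile; leave everything else eager.

    Returns the number of modules that were wrapped.
    """
    # Collect targets first so the module tree is not mutated while iterating it.
    targets = [
        (parent, child_name, child)
        for parent in model.modules()
        for child_name, child in parent.named_children()
        if isinstance(child, norm_cls)
    ]
    for parent, child_name, child in targets:
        # torch.compile returns a wrapper module; compilation itself is lazy
        # and happens on the first forward call.
        setattr(parent, child_name, torch.compile(child))
    return len(targets)
```

For example, `compile_norm_layers(nn.Sequential(ToyBlock(64), ToyBlock(64)))` wraps the two norm layers and returns 2. Wrapping individual modules rather than the whole model keeps the compiled region small, so only the norm layers pay compile cost and only they get fused kernels.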

regular RMSNorm: [screenshot from 2024-05-09, 5:09:17 PM]

with the new compile_rmsnorm option enabled: [screenshot from 2024-05-09, 5:24:39 PM]

[screenshot from 2024-05-09, 5:09:57 PM]

2 - UX: for now I added compile_rmsnorm as its own option, so users can quickly try either whole-model compile or norm-only compile. If compile is true, the RMSNorm layers are not compiled separately (they are already included in the full-model compile), and a short note is logged.
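The precedence between the two options can be sketched as a small decision function. This is a hypothetical sketch: `plan_compilation` and its return values are illustrative, and `compile_model` stands in for the `compile` config option (renamed here to avoid shadowing Python's built-in `compile`).

```python
import logging

logger = logging.getLogger(__name__)


def plan_compilation(compile_model: bool, compile_rmsnorm: bool) -> str:
    """Decide what to compile from the two flags (hypothetical helper).

    Returns "full" (whole model), "norm_only" (just the norm layers) or "none".
    """
    if compile_model:
        if compile_rmsnorm:
            # Full-model compile already covers the norm layers, so the
            # norm-only option is redundant: log a note and ignore it.
            logger.info(
                "compile=True: norm layers are compiled as part of the full model; "
                "ignoring compile_rmsnorm."
            )
        return "full"
    if compile_rmsnorm:
        return "norm_only"
    return "none"
```

Full compile wins whenever it is enabled, so setting both flags behaves exactly like setting compile alone.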

3 - Enabling this option with other norm types does not appear to give any speedup (but also raises no errors), so I did not add a check that restricts it to RMSNorm (but can add that).
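If such a check were added, it could be as small as the following sketch. The `norm_type` strings, the allow-list, and the `should_compile_norms` helper are assumptions for illustration, not code from this PR.

```python
# Hypothetical allow-list of norm types for which norm-only compile is applied.
SUPPORTED_FOR_NORM_COMPILE = {"rmsnorm"}


def should_compile_norms(norm_type: str, compile_rmsnorm: bool) -> bool:
    """Honour compile_rmsnorm only when the model actually uses RMSNorm."""
    return compile_rmsnorm and norm_type.lower() in SUPPORTED_FOR_NORM_COMPILE
```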

lessw2020 · May 10 '24 00:05

Is point 2 saying that to get "full" compile you need to set both compile = true and compile_rmsnorm = true?

drisspg · May 10 '24 00:05

> Is point 2 saying that to get "full" compile you need to set both compile = true and compile_rmsnorm = true?

No. I updated the text to be more specific: if compile = true in the config, you get full-model compile, which already includes the RMSNorm layers.

lessw2020 · May 10 '24 01:05

Closing, as we removed this feature in #535.

tianyu-l · Aug 21 '24 04:08