torchtitan
selective compilation - norm layers only
This PR adds the option to selectively compile only the norm layers, and is mainly targeted at RMSNorm. By compiling just the norm layers when using rmsnorm, we get speedups nearly comparable to using the fusedRMSNorm triton kernel. Credit @wconstab for this idea.
regular rmsnorm:
with the new compile_rmsnorm enabled:
2 - UX - I enabled compile_rmsnorm as its own option for now so users can quickly try whole-model or norm-only compile. If compile is true, the rmsnorm layers are not compiled separately (they are already included in the generic full-model compile) and a minor note is logged.
3 - Using other norms with this option enabled does not appear to add any speedup (but also causes no errors), so I did not add a check to only compile if the norm is rmsnorm (but can add that).
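The flag interaction described in item 2 can be sketched as a small decision function. This is only an illustration in plain Python, not the code from this PR: `CompileConfig` and `resolve_compile_scope` are hypothetical names, and only the two flags `compile` and `compile_rmsnorm` come from the description above.

```python
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class CompileConfig:
    # Hypothetical container; the field names mirror the two flags in this PR.
    compile: bool = False
    compile_rmsnorm: bool = False


def resolve_compile_scope(cfg: CompileConfig) -> str:
    """Return "full", "norm_only", or "none" for the given flags (assumed semantics)."""
    if cfg.compile:
        if cfg.compile_rmsnorm:
            # Full-model compile already covers the norm layers, so the
            # norm-only flag has no separate effect; only a note is logged.
            logger.info(
                "compile=True already includes the norm layers; "
                "compile_rmsnorm is not applied separately."
            )
        return "full"
    if cfg.compile_rmsnorm:
        # Only the norm layers would be compiled.
        return "norm_only"
    return "none"
```

Under these assumptions, `compile=True` alone is enough for a full compile, and `compile_rmsnorm=True` only matters when `compile` is false.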
Is 2 saying that in order to have "full" compile you need to set both compile=true and compile_rmsnorm=true?
I updated the text to be more specific, but no - if compile = true in the config, then you get full compile including the rmsnorm layers.
Closing, as we removed the feature in #535.