Megatron-LM
[ENHANCEMENT] Integrating torch.compile with Megatron/TransformerEngine
**Is your feature request related to a problem? Please describe.**
- PyTorch 2.0 added a new feature, `torch.compile`, which captures computation/communication ops into an FX graph and generates an optimized execution plan by fusing ops and overlapping computation with communication.
- Meanwhile, many cool features are built on top of `torch.compile`, like `FlexAttention`, which provides a flexible API that automatically generates high-performance kernels for many attention variants.
- We would like to explore whether `torch.compile` + Megatron can unleash even greater power in both the LLM training and inference space.
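As a minimal sketch of what graph capture looks like, the snippet below compiles a toy module (a hypothetical stand-in, not Megatron code) and checks that the compiled path agrees with eager execution:

```python
import torch

# A toy module standing in for a transformer sublayer (hypothetical, for
# illustration only; it is not Megatron code).
class MLP(torch.nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.fc1 = torch.nn.Linear(dim, 4 * dim)
        self.fc2 = torch.nn.Linear(4 * dim, dim)

    def forward(self, x):
        # gelu between two matmuls is a classic fusion candidate once the
        # ops are captured in an FX graph
        return self.fc2(torch.nn.functional.gelu(self.fc1(x)))

model = MLP()
# backend="eager" exercises Dynamo's graph capture without requiring the
# Inductor C++ toolchain; drop it to get the full optimizing pipeline.
compiled = torch.compile(model, backend="eager")

x = torch.randn(8, 16)
eager_out = model(x)
compiled_out = compiled(x)
```

The numerical results of the two paths should match, since `torch.compile` only changes how the same ops are executed.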
**Describe the solution you'd like**
- Enable `torch.compile` on top of Megatron modules, tensor parallel, context parallel, and the underlying TransformerEngine; capture computation/communication graphs; investigate better fusion, computation/communication overlap optimizations, etc.
- Explore local compilation, e.g., integrating `FlexAttention` into the Megatron attention module.
**Additional context**
- I looked at the code here and in the TransformerEngine repo and found that many places are decorated with `no_torch_dynamo`, which simply skips Dynamo tracing. However, it seems Megatron supports `torch.compile` via `jit_fuser = torch.compile`. I'd like more context on the discrepancy between these two repos over whether `torch.compile` is allowed.
- I understand TransformerEngine has a lot of hand-written CUDA kernels and fused modules, which already leverage fusion to squeeze out performance. I think `torch.compile` can provide benefits beyond fusion alone, such as better fusion and computation/communication overlap in distributed setups.
- We would like to work with the experts here to figure out other opportunities for a better user experience and performance by integrating `torch.compile` with Megatron/TransformerEngine.
Hi, thanks for the issue. Is there anything more specific you'd like to contribute?
@ericharper Thanks for your reply! We have a list of enhancements we could contribute to Megatron/TransformerEngine, mainly focused on integrating `torch.compile` with Megatron/TransformerEngine. Some examples include:
- We have already made many fixes so that `torch.compile` can capture a full graph for Megatron; we are almost done with TP and are working on the SP/CP/PP use cases. We think most of the issues should be fixed on the PyTorch side, but we have found some very tricky cases where we need to work with the Megatron community on a better solution. What we would directly contribute in this bucket includes: 1) adding torch-compiled-mode unit tests to ensure it works for the major use cases; 2) opening a few issues about the cases that block us from capturing a full graph (e.g., proper support for `torch.cuda.current_device`), working with the Megatron community to get feedback on our solutions, and finally fixing these gaps.
- We have some other features, like `FlexAttention`, which provides a flexible API that automatically generates high-performance kernels for many attention variants. We believe this would help LLM researchers and engineers quickly try out new attention-score modification and masking ideas. We'd like to check whether this is aligned with the Megatron community's vision; if so, we could add it to Megatron so more users can access it easily.
- In general, we are happy to work with the Megatron community to enable any features around `torch.compile`, so that users in the LLM space can take advantage of cutting-edge technology from both Megatron and PyTorch.
Let me know if you have any questions! Thank you!