
Fix param input order for cudagraph

Open • yifeis-nv opened this issue 1 year ago • 2 comments

Description

When I use CUDA graphs with pipeline parallelism, the gradients become incorrect and eventually turn into NaNs. After debugging, I traced this to a small bug in TE's graph.py.

Fixes # (issue)

make_graphed_callables in TE builds the backward graph with torch.autograd.grad, so the weights are passed to torch.autograd.grad as part of its inputs. The order of those inputs must therefore match the order used in the forward graph; otherwise the backward pass produces incorrect gradients.
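A minimal sketch in plain PyTorch (not TE's graph.py) of why the order matters: torch.autograd.grad returns one gradient per entry of `inputs`, in that order, so any code that maps the returned tuple back onto parameters must use the same order. `assign_grads` is a hypothetical helper standing in for that mapping step. Because both parameters have the same shape, the mismatch is silent rather than raising an error.

```python
import torch

# Two parameters with the same shape, so a misordering fails silently
# instead of raising a shape mismatch.
w1 = torch.nn.Parameter(torch.tensor([2.0]))
w2 = torch.nn.Parameter(torch.tensor([3.0]))
x = torch.tensor([5.0])


def loss_fn():
    # loss = w1 * x + w2**2, so d(loss)/d(w1) = x = 5 and d(loss)/d(w2) = 2 * w2 = 6
    return (w1 * x + w2 ** 2).sum()


def assign_grads(params, grads):
    # Hypothetical helper: pairs gradients with parameters purely by position.
    for p, g in zip(params, grads):
        p.grad = g.detach().clone()


# One gradient per entry of `inputs`, returned in the order of `inputs`.
grads = torch.autograd.grad(loss_fn(), inputs=(w1, w2))

# Correct: gradients consumed in the same order as `inputs`.
assign_grads((w1, w2), grads)
print(w1.grad.item(), w2.grad.item())  # 5.0 6.0

# Buggy: same gradients, consumed in a different parameter order.
assign_grads((w2, w1), grads)
print(w1.grad.item(), w2.grad.item())  # 6.0 5.0
```

In the second call each parameter silently receives the other's gradient, which is the kind of error that can compound over training steps into NaNs.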

Type of change

  • [ ] Documentation change (change only to the documentation, either a fix or a new content)
  • [x] Bug fix (non-breaking change which fixes an issue)
  • [ ] New feature (non-breaking change which adds functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • [ ] Infra/Build change
  • [ ] Code refactor

Changes

Reorder the weight inputs in the CUDA graph module (graph.py) so that their order in the backward torch.autograd.grad call matches the order in the forward graph.

Checklist:

  • [x] I have read and followed the contributing guidelines
  • [x] The functionality is complete
  • [ ] I have commented my code, particularly in hard-to-understand areas
  • [ ] I have made corresponding changes to the documentation
  • [x] My changes generate no new warnings
  • [ ] I have added tests that prove my fix is effective or that my feature works
  • [ ] New and existing unit tests pass locally with my changes

yifeis-nv • Aug 27 '24 03:08

This fix seems plausible. make_graphed_callables appears to expect sample_args to be ordered first by layer number, then by microbatch, then by model chunk:

https://github.com/NVIDIA/TransformerEngine/blob/0ece13ea3618bff300e9de304f32010480ddafce/transformer_engine/pytorch/graph.py#L236-L238

However, some of our MLPerf wrappers order by microbatch, then layer number, then model chunk: https://gitlab-master.nvidia.com/dl/mlperf/optimized/-/blob/main/large_language_model/pytorch/custom_callbacks.py#L249-L254 Pinging @ksivaman. Also, could you sign your commit so it passes the DCO check?
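To make the two orderings concrete, here is a small sketch of how a flattened list index could be computed under each convention. It assumes "ordered first by layer number, then by microbatch, then by model chunk" means the layer index varies fastest and the model chunk slowest; both function names are hypothetical and are not taken from TE or the MLPerf code.

```python
def layer_first_index(m_chunk, microbatch, layer, num_microbatches, num_layers):
    # Layer varies fastest, then microbatch, then model chunk (slowest).
    return (m_chunk * num_microbatches + microbatch) * num_layers + layer


def microbatch_first_index(m_chunk, microbatch, layer, num_microbatches, num_layers):
    # Microbatch varies fastest, then layer, then model chunk (slowest).
    return (m_chunk * num_layers + layer) * num_microbatches + microbatch


num_model_chunks, num_microbatches, num_layers = 2, 2, 2

# Collect every (chunk, microbatch, layer) slot where the two conventions disagree.
mismatches = []
for m_chunk in range(num_model_chunks):
    for microbatch in range(num_microbatches):
        for layer in range(num_layers):
            a = layer_first_index(m_chunk, microbatch, layer, num_microbatches, num_layers)
            b = microbatch_first_index(m_chunk, microbatch, layer, num_microbatches, num_layers)
            if a != b:
                mismatches.append((m_chunk, microbatch, layer, a, b))

print(len(mismatches))  # 4
```

With 2 chunks, 2 microbatches and 2 layers, half of the 8 slots map to different positions, so a list built with one convention and indexed with the other would pair sample args with the wrong (layer, microbatch) slot.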

Thanks for the reminder! I have signed my commit. As I understand the code, the order used in the MLPerf wrapper you referenced does not affect the capture order inside make_graphed_callables: capture still proceeds first by layer number, then by microbatch, and finally by model chunk, so the issue described above still occurs. I believe this is also why that code was modified to isolate the captures of different microbatches (which prevents them from sharing a memory pool and is likely to increase memory overhead): https://gitlab-master.nvidia.com/dl/mlperf/optimized/-/blob/main/large_language_model/pytorch/custom_callbacks.py#L216-237

yifeis-nv • Aug 28 '24 03:08

/te-ci pytorch

timmoon10 • Oct 03 '24 01:10