Alp Dener
This PR moves all the Userbuffers code from TE/pytorch to TE/common and refactors the interfaces so that TE/common/userbuffers is accessible to all framework integrations. **To do:** - [x] Move userbuffers from...
`torch.distributed.fsdp.FullyShardedDataParallel` cannot scatter/gather the intermediate/activation tensors that TE modules pack into the autograd context at the end of their forward passes, resulting in *globally sized* activation and FP8 weight tensors...
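To illustrate what scattering a saved tensor means, here is a framework-free sketch that models a flattened activation as a Python list. Each of `world_size` ranks keeps only an equal-length, zero-padded slice, and the full tensor is rebuilt by concatenating the slices in rank order before backward needs it. The helpers `shard` and `gather` are hypothetical. In practice the gather step would presumably be a collective such as `torch.distributed.all_gather_into_tensor` over the FSDP process group, not a list concatenation.

```python
import math


def shard(flat, world_size, rank):
    """Return this rank's slice of a flat buffer, zero-padded so that
    every rank holds the same number of elements."""
    shard_len = math.ceil(len(flat) / world_size)
    start = rank * shard_len
    piece = flat[start:start + shard_len]
    return piece + [0] * (shard_len - len(piece))


def gather(shards, numel):
    """Concatenate per-rank slices (in rank order) and drop the padding
    to recover the original, globally sized buffer."""
    full = [x for s in shards for x in s]
    return full[:numel]


# A 10-element "activation" split across 4 ranks: each rank stores 3 elements
# instead of 10, and the last rank's slice carries two padding zeros.
activation = list(range(10))
shards = [shard(activation, 4, r) for r in range(4)]
restored = gather(shards, len(activation))
```

Padding to a common length mirrors how equal-sized shards keep a collective's buffers uniform across ranks.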
This PR modifies `te.distributed.checkpoint(...)` to preserve the `torch.amp.autocast(...)` context from the forward pass during the recompute phase. Reported in #787.
TorchDynamo has known limitations for `autograd.Function` implementations and `autograd.graph` hooks. Activation recompute utilizes *both* of those mechanisms, so this PR disables TorchDynamo on `te.distributed.checkpoint()` via the `@no_torch_dynamo()` decorator.
This PR moves Userbuffers and the comm+GEMM overlap algorithms from TE/PyTorch to TE/common, with refactored interfaces that remove the PyTorch dependency.
The Userbuffers configuration dictionary now has new `"comm_priority"` and `"gemm_priority"` options, both defaulting to `-1`. Attn. @rachitgarg91
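A minimal sketch of how per-layer defaults for the two new options could be resolved, assuming the configuration maps layer names to dicts of options. The layer names and the helper `with_priority_defaults` are illustrative assumptions, not the actual TE code. The description does not say how the values are interpreted; reading them as CUDA stream priorities, where numerically lower values mean higher priority, is also an assumption.

```python
import copy

# Defaults stated in the PR description.
DEFAULT_PRIORITIES = {"comm_priority": -1, "gemm_priority": -1}


def with_priority_defaults(ub_cfgs):
    """Return a copy of ub_cfgs in which every layer entry has both
    priority keys, filling in -1 for any the user did not set."""
    out = copy.deepcopy(ub_cfgs)  # leave the caller's dict untouched
    for cfg in out.values():
        for key, default in DEFAULT_PRIORITIES.items():
            cfg.setdefault(key, default)
    return out


# Hypothetical user config: one layer overrides comm_priority, one sets nothing.
user_cfgs = {
    "qkv_fprop": {"comm_priority": 0},
    "proj_dgrad": {},
}
resolved = with_priority_defaults(user_cfgs)
```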
When `initialize_ub()`+`destroy_ub()` pairs are called more than once (e.g., in-process restarts), the cuBLAS workspace allocation is mishandled and grows exponentially. This PR safeguards the workspace expansion...
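A toy model of the failure mode. It assumes, without confirmation from the description, that each `initialize_ub()` call widens the existing workspace by a factor equal to the number of overlap streams, and that `destroy_ub()` does not shrink it back. Scaling the *current* size compounds across restarts. Deriving the target from the *base* size makes the expansion idempotent. The class and its methods are hypothetical and only track sizes; no memory is allocated.

```python
class WorkspaceModel:
    """Tracks the size (in bytes) of a cuBLAS-style workspace that is
    widened once per overlap stream during Userbuffers initialization."""

    def __init__(self, base_bytes, num_streams):
        self.base_bytes = base_bytes    # single-stream workspace size
        self.num_streams = num_streams  # hypothetical overlap stream count
        self.current_bytes = base_bytes

    def expand_compounding(self):
        # Scales whatever is already allocated, so every init/destroy
        # cycle multiplies the size again: growth is exponential.
        self.current_bytes *= self.num_streams

    def expand_idempotent(self):
        # Scales the base size, so repeated cycles settle at one target.
        target = self.base_bytes * self.num_streams
        self.current_bytes = max(self.current_bytes, target)


# Simulate three in-process restarts, each calling the expansion once.
buggy = WorkspaceModel(base_bytes=32, num_streams=4)
safe = WorkspaceModel(base_bytes=32, num_streams=4)
for _ in range(3):
    buggy.expand_compounding()
    safe.expand_idempotent()
```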
When the Userbuffers config dictionary sets the overlap method to `ring-exchange` or `pipeline` for any `*_dgrad` layer, that layer's `*_wgrad` overlap must be disabled in order for the `ub_overlap_rs_dgrad=True` option...
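A minimal sketch of the rule, assuming the config maps layer names such as `fc1_dgrad` to dicts with a `"method"` key, and representing "disabled" by dropping the matching `*_wgrad` entry. The real change may disable the overlap differently, and the helper `disable_conflicting_wgrad` is hypothetical. The method strings are copied from the description above.

```python
import copy

# Overlap methods that, per the description, require the paired *_wgrad
# overlap to be turned off when ub_overlap_rs_dgrad=True.
RS_DGRAD_METHODS = {"ring-exchange", "pipeline"}


def disable_conflicting_wgrad(ub_cfgs):
    """Return a copy of ub_cfgs without the *_wgrad entry of every *_dgrad
    layer whose overlap method is ring-exchange or pipeline."""
    out = copy.deepcopy(ub_cfgs)
    for name, cfg in ub_cfgs.items():
        if name.endswith("_dgrad") and cfg.get("method") in RS_DGRAD_METHODS:
            wgrad_name = name[: -len("_dgrad")] + "_wgrad"
            out.pop(wgrad_name, None)  # no-op if wgrad was never configured
    return out


# fc1_dgrad uses pipeline overlap, so fc1_wgrad is dropped; qkv_dgrad sets
# no method, so qkv_wgrad is kept.
cfgs = {
    "fc1_dgrad": {"method": "pipeline"},
    "fc1_wgrad": {},
    "qkv_dgrad": {},
    "qkv_wgrad": {},
}
resolved = disable_conflicting_wgrad(cfgs)
```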
This PR implements both the old-style and the new FFI-based XLA custom calls in C++, along with the corresponding JAX primitive and its custom partitioning rules. Custom partitioning rules for a `LHS:([B,] M, K)...
This PR integrates the TE/common cuBLASMp bindings into the TE/JAX CollectiveGemm custom op.