Alp Dener

Results 11 issues of Alp Dener

This PR moves all the userbuffers code in TE/pytorch to TE/common and refactors the interfaces to make TE/common/userbuffers accessible to all framework integrations. **To do:** - [x] Move userbuffers from...

`torch.distributed.fsdp.FullyShardedDataParallel` cannot scatter/gather the intermediate/activation tensors that TE modules pack into the autograd context at the end of their forward passes, resulting in *globally sized* activation and FP8 weight tensors...

1.7.0

This PR modifies `te.distributed.checkpoint(...)` to preserve the `torch.amp.autocast(...)` context from the forward pass during the recompute phase. Reported in #787.

bug
1.7.0

TorchDynamo has known limitations for `autograd.Function` implementations and `autograd.graph` hooks. Activation recompute utilizes *both* of those mechanisms, so this PR disables TorchDynamo on `te.distributed.checkpoint()` via the `@no_torch_dynamo()` decorator.

# Description This PR moves Userbuffers and comm+GEMM overlap algorithms from TE/PyTorch to TE/common with refactored interfaces to remove the PyTorch dependency. ## Type of change - [ ] Documentation...

enhancement

# Description The Userbuffers configuration dictionary now has new `"comm_priority"` and `"gemm_priority"` options, both defaulting to -1. Attn. @rachitgarg91 ## Type of change - [ ]...

enhancement
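A minimal sketch of how the two priority options described above might be filled in for a per-layer config entry. Only the key names `"comm_priority"` and `"gemm_priority"` and their default of -1 come from the PR description. The helper `with_priority_defaults` and the `"method"` entry are hypothetical and are not part of the TE API.

```python
# Defaults taken from the PR description; everything else here is illustrative.
DEFAULT_PRIORITIES = {"comm_priority": -1, "gemm_priority": -1}


def with_priority_defaults(layer_cfg):
    """Return a copy of layer_cfg with missing priority options set to -1.

    Hypothetical helper: user-supplied values always win over the defaults.
    """
    merged = dict(DEFAULT_PRIORITIES)
    merged.update(layer_cfg)
    return merged


# A layer entry that overrides only one of the two options.
example = with_priority_defaults({"method": "pipeline", "comm_priority": 0})
```

Copying the defaults before merging leaves the caller's dictionary untouched, so the same entry can be reused across layers.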

# Description In cases where `initialize_ub()`+`destroy_ub()` pairs are called more than once (e.g. in-process restarts), the cuBLAS workspace allocation is mishandled and grows exponentially. This PR safeguards the workspace expansion...

bug
2.3.0
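The workspace growth described in the `initialize_ub()`/`destroy_ub()` issue above can be illustrated with a small sketch. It assumes the expansion multiplies the workspace size by a stream count on every initialization; that mechanism, the 32 MiB base size, and the names `naive_expand` and `guarded_expand` are assumptions for illustration, not the actual TE implementation.

```python
BASE_BYTES = 32 * 1024 * 1024  # hypothetical base workspace size


def naive_expand(current_bytes, num_streams):
    """Unguarded expansion: scales whatever size is currently recorded."""
    return current_bytes * num_streams


def guarded_expand(current_bytes, num_streams, base_bytes=BASE_BYTES):
    """Guarded expansion: the target is derived from the base size, so
    calling this repeatedly never grows past base_bytes * num_streams."""
    target = base_bytes * num_streams
    return current_bytes if current_bytes >= target else target


def simulate(expand, cycles, num_streams=4):
    """Apply `expand` once per initialize/destroy cycle and return the size."""
    size = BASE_BYTES
    for _ in range(cycles):
        size = expand(size, num_streams)
    return size


naive_size = simulate(naive_expand, cycles=3)      # grows with every cycle
guarded_size = simulate(guarded_expand, cycles=3)  # stays at the first target
```

Deriving the target from a fixed base makes the expansion idempotent, which is the property an in-process restart needs.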

# Description When the Userbuffers config dictionary sets the overlap method to `ring-exchange` or `pipeline` for any `*_dgrad` layer, that layer's `*_wgrad` overlap needs to be disabled in order for `ub_overlap_rs_dgrad=True` option...

bug
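The `*_dgrad`/`*_wgrad` rule in the issue above could be enforced with a config pass along these lines. This is a sketch under stated assumptions: the layer names (`qkv`, `fc1`), the `"bulk"` method value, and the choice to express "disabled" by dropping the `*_wgrad` entry are all hypothetical. Only the `*_dgrad`/`*_wgrad` suffixes and the `ring-exchange` and `pipeline` method names come from the description.

```python
# Methods that, per the description, conflict with overlapping the same layer's wgrad.
CONFLICTING_DGRAD_METHODS = {"ring-exchange", "pipeline"}


def disable_conflicting_wgrad(ub_cfgs):
    """Return a copy of ub_cfgs without the *_wgrad entries whose matching
    *_dgrad entry uses a conflicting overlap method.

    Hypothetical helper: removing the entry stands in for "overlap disabled".
    """
    result = {name: dict(cfg) for name, cfg in ub_cfgs.items()}
    for name, cfg in ub_cfgs.items():
        if name.endswith("_dgrad") and cfg.get("method") in CONFLICTING_DGRAD_METHODS:
            wgrad_name = name[: -len("_dgrad")] + "_wgrad"
            result.pop(wgrad_name, None)  # no-op if the entry is absent
    return result


cfgs = {
    "qkv_dgrad": {"method": "ring-exchange"},
    "qkv_wgrad": {"method": "bulk"},
    "fc1_dgrad": {"method": "bulk"},
    "fc1_wgrad": {"method": "bulk"},
}
cleaned = disable_conflicting_wgrad(cfgs)
```

Here `qkv_wgrad` is dropped because `qkv_dgrad` uses `ring-exchange`, while the `fc1` entries are left alone.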

# Description Implements both old-style and new FFI-based XLA custom calls in C++, and the corresponding JAX primitive including custom partitioning rules. Custom partitioning rules for a `LHS:([B,] M, K)...

jax

# Description This PR integrates TE/common cuBLASMp bindings into the TE/JAX CollectiveGemm custom op. ## Type of change - [ ] Documentation change (change only to the documentation, either a...

2.10.0