TransformerEngine
Improve import speed with lazy initialization
Currently, importing transformer_engine takes ~10s on my machine, and it also starts a background process pool because of all the JIT initialization that runs at import time.
It would be better if all the expensive parts were initialized lazily, or at least only when a specific subpackage/module is imported, not when the root-level package is imported; see the sketch below.
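A minimal sketch of what lazy loading could look like in the root package, using PEP 562's module-level __getattr__; the submodule names here are illustrative, not the package's actual layout:

```python
# Illustrative sketch of a lazy root __init__.py, not the actual package code.
# PEP 562's module-level __getattr__ defers submodule imports until first
# attribute access, so a bare "import transformer_engine" stays cheap.
import importlib

_LAZY_SUBMODULES = {"pytorch", "jax", "common"}  # hypothetical heavy submodules

def __getattr__(name):
    if name in _LAZY_SUBMODULES:
        # import_module caches the result in sys.modules, so the expensive
        # import runs at most once, on first access.
        return importlib.import_module(f"{__name__}.{name}")
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```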
This also affects, e.g., "import accelerate", which pulls in transformer_engine; that is pretty annoying, as I often use accelerate for workloads unrelated to transformer_engine.
@ksivaman Could you take a look at this? Maybe we could move the jitting into the actual modules? In general I agree that we should not do anything CUDA-related on import, but rather on first use; a sketch of that is below.
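A sketch of what first-use initialization could look like, with a hypothetical _ensure_jit_initialized() standing in for the real JIT/CUDA setup:

```python
import functools

@functools.lru_cache(maxsize=None)
def _ensure_jit_initialized():
    # Stand-in for the real work: warming up fused kernels, starting the
    # compile process pool, touching the CUDA context, and so on.
    # lru_cache guarantees this body runs at most once per process.
    print("running one-time JIT setup...")

def fused_op(x):
    _ensure_jit_initialized()  # first call pays the cost; import stays free
    return x  # placeholder for the actual fused computation
```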
I've also noticed that importing certain submodules of the pytorch package is so slow that VS Code's debugger hangs :( . An example: from transformer_engine.pytorch.float8_tensor import Float8Tensor already induces a hang for me when training with multiple processes.
Perhaps this is the reason; it looks like a module might be re-imported multiple times: https://github.com/microsoft/debugpy/issues/349#issuecomment-671536970
I encountered a similar hang. I use Megatron-LM, and after installing transformer-engine I hit a hang during torch.distributed initialization; the problem goes away after uninstalling transformer-engine.
I also observed that many subprocesses were created while it hung.
Agreed, this takes more than 3 minutes on our server. Is there any way to bypass this issue?
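Not a fix in the package itself, but one workaround in user code is to defer the import to the call site, so scripts that never touch transformer_engine skip the cost entirely. A minimal sketch, assuming the standard transformer_engine.pytorch.Linear layer:

```python
def build_te_linear(in_features, out_features):
    # Importing inside the function delays the expensive initialization
    # until this is actually called, and skips it on other code paths.
    import transformer_engine.pytorch as te
    return te.Linear(in_features, out_features)
```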