TransformerEngine
FP8 & Activation checkpointing do not play well together
Activation checkpointing recomputes activations, so it needs to re-execute parts of the forward pass.
This re-execution should be allowed and should not affect the FP8 (amax) history. Currently the error described in https://github.com/NVIDIA/TransformerEngine/pull/93 is thrown instead.
Would it be possible to cover this case so that activation checkpointing works seamlessly?
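For reference, a minimal sketch of the failing pattern (hypothetical sizes and recipe settings; the TE module and autocast names match those in the tracebacks below):

import torch
import transformer_engine.pytorch as te
from torch.utils.checkpoint import checkpoint
from transformer_engine.common import recipe

layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16).cuda()
x = torch.randn(2048, 2, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe.DelayedScaling()):
    # Wrapping the TE layer in torch's generic activation checkpointing means
    # its forward runs again during backward, which trips the FP8 assertions
    # shown in the tracebacks below.
    out = checkpoint(layer, x, use_reentrant=False)
out.mean().backward()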
+1
Here is the error:
File "/miniconda/lib/python3.10/site-packages/torch/_tensor.py", line 491, in backward
torch.autograd.backward(
File "/miniconda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 204, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/miniconda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 804, in unpack_hook
raise AssertionError(
AssertionError: if early stop is enabled, we don't expect to reach here
Without checkpointing it is hard to use the full compute capability of the H100, because there is not enough VRAM.
For clarity, I am getting the following error when checkpointing a TransformerLayer with FP8:
...
Traceback (most recent call last):
File "/home/user/llm/llmte/te_model.py", line 129, in <module>
x.mean().backward()
File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 491, in backward
torch.autograd.backward(
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 204, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 274, in apply
return user_fn(self, *args)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 261, in backward
outputs = ctx.run_function(*detached_inputs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1505, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1514, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/transformer.py", line 499, in forward
self_attention_outputs = self.self_attention(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1505, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1514, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/attention.py", line 1243, in forward
layernorm_qkv_outputs = self.layernorm_qkv(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1505, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1514, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 294, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/layernorm_linear.py", line 842, in forward
with self.prepare_forward(inp, is_first_microbatch) as inp:
File "/usr/lib/python3.10/contextlib.py", line 135, in __enter__
return next(self.gen)
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/base.py", line 612, in prepare_forward
add_amax_to_global_buffer(self.fp8_meta, forward=True)
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/fp8.py", line 135, in add_amax_to_global_buffer
assert fp8_meta[buffer_position_key] == len(_global_fp8_buffer[buffer_key]) - 1, \
AssertionError: Same module is being invoked more than once inside an `fp8_autocast` region when using FP8 with amax reduction. This behavior is currently unsupported. For more details and correct usage, please see https://github.com/NVIDIA/TransformerEngine/pull/93.
So I guess that to fix this we just need to detect whether we are inside the checkpoint recomputation path and, if so, not modify the history at all.
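A conceptual sketch of that idea (not TE's actual implementation; the flag and helper names here are hypothetical): the checkpoint machinery would mark the recompute pass, and the FP8 bookkeeping would leave the amax history untouched whenever the mark is set.

import contextlib

_IN_RECOMPUTE = False  # hypothetical module-level flag

@contextlib.contextmanager
def recompute_phase():
    # Set while a checkpointed forward is being re-executed during backward.
    global _IN_RECOMPUTE
    _IN_RECOMPUTE = True
    try:
        yield
    finally:
        _IN_RECOMPUTE = False

def update_amax_history(history, amax):
    # Only the original forward pass appends to the history; the recompute
    # pass sees the flag and leaves the history unchanged.
    if _IN_RECOMPUTE:
        return history
    return history + [amax]

history = update_amax_history([], 1.5)            # original forward: [1.5]
with recompute_phase():
    history = update_amax_history(history, 1.5)   # recompute: still [1.5]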
How are you currently doing activation checkpointing? Are you using an underlying toolkit such as NeMo?
@ksivaman: I'm using checkpoint_wrapper from torch.distributed.algorithms._checkpoint.checkpoint_wrapper, coupled with FSDP.
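For context, a minimal sketch of that setup (hypothetical model and layer sizes; checkpoint_wrapper, apply_activation_checkpointing, and FSDP are the standard PyTorch utilities named above):

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")  # e.g. launched with torchrun

# Hypothetical stack of TE transformer layers.
model = torch.nn.Sequential(
    *[te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16)
      for _ in range(4)]
).cuda()

# Wrap each TransformerLayer with torch's generic activation checkpointing,
# then shard the result with FSDP (the combination described above).
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda m: isinstance(m, te.TransformerLayer),
)
model = FSDP(model, device_id=torch.cuda.current_device())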
For this purpose, you can use the checkpoint function provided in TransformerEngine; you can find the documentation here. It handles the additional bookkeeping required for FP8 execution with activation recompute.
@ksivaman @denera: is there an example of using transformer_engine.pytorch.checkpoint? Is it possible to add this to the FSDP example in TE?
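In the meantime, a rough sketch of swapping in the TE-provided checkpoint function (hypothetical sizes; the exact signature of transformer_engine.pytorch.checkpoint has changed across TE releases, with older versions expecting extra positional arguments before the inputs, so check the documentation for your installed version):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16).cuda()
x = torch.randn(2048, 2, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe.DelayedScaling()):
    # te.checkpoint stands in for torch.utils.checkpoint.checkpoint and is
    # meant to keep the FP8 amax/scale bookkeeping consistent across the
    # recomputed forward.
    out = te.checkpoint(layer, x)
out.mean().backward()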