TransformerEngine
Selective Activation Checkpointing with LayerNormMLP
Hi all,
I was wondering whether it is possible to do selective activation checkpointing with LayerNormMLP, where we only recompute FFN1 (and not FFN2) in the backward pass, so that the ffn1_out and gelu_out activations (the largest activations by memory) don't need to be saved.
This has been done in OPT (https://github.com/facebookresearch/metaseq/blob/f7ffa5fd61cf90f498a36d365c13dd7f1a912ff7/metaseq/modules/sequence_parallel_transformer_layer.py#L250C20-L250C33), so I wonder whether it is possible in TransformerEngine as well, because it would be awesome to use it with FP8!
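To illustrate the idea outside of TE (plain PyTorch, no FP8, LayerNorm and biases omitted), here is a minimal sketch. The class and parameter names (`SelectiveCheckpointMLP`, `fc1_weight`, `fc2_weight`) are just placeholders I made up, not TE's API; it assumes the weights require grad:

```python
import torch
import torch.nn.functional as F

class SelectiveCheckpointMLP(torch.autograd.Function):
    """Sketch: save only the MLP input; recompute ffn1_out / gelu_out in backward."""

    @staticmethod
    def forward(ctx, ln_out, fc1_weight, fc2_weight):
        ffn1_out = F.linear(ln_out, fc1_weight)   # FFN1
        gelu_out = F.gelu(ffn1_out)
        out = F.linear(gelu_out, fc2_weight)      # FFN2
        # Only the small (hidden-size h) input and the weights are kept;
        # ffn1_out / gelu_out (size 4h) are discarded and rebuilt in backward.
        ctx.save_for_backward(ln_out, fc1_weight, fc2_weight)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        ln_out, fc1_weight, fc2_weight = ctx.saved_tensors
        # Recompute FFN1 + GELU (but not FFN2) from the saved input.
        with torch.enable_grad():
            ln_out_ = ln_out.detach().requires_grad_()
            ffn1_out = F.linear(ln_out_, fc1_weight)
            gelu_out = F.gelu(ffn1_out)
        # FFN2 grads come directly from grad_out and the recomputed gelu_out.
        grad_gelu = grad_out @ fc2_weight
        grad_fc2_w = grad_out.flatten(0, -2).t() @ gelu_out.flatten(0, -2)
        # Backprop through the recomputed GELU + FFN1.
        grad_ln, grad_fc1_w = torch.autograd.grad(
            gelu_out, (ln_out_, fc1_weight), grad_outputs=grad_gelu)
        return grad_ln, grad_fc1_w, grad_fc2_w
```

Something along these lines inside LayerNormMLP (with the FP8 casts handled properly) is what I'm after.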
Thank you!
@sudhakarsingh27 This is basically the same as what we discussed about improving the checkpoint logic to allow for early stopping.