
Selective Activation Checkpointing with LayerNormMLP

Open denizokt opened this issue 1 year ago • 1 comment

Hi all,

I was wondering whether it is possible to do selective activation checkpointing with LayerNormMLP, where we recompute only FFN1 and not FFN2, so that we do not have to save the ffn1_out and gelu_out activations (the largest activations in memory).

This has been done in OPT (see https://github.com/facebookresearch/metaseq/blob/f7ffa5fd61cf90f498a36d365c13dd7f1a912ff7/metaseq/modules/sequence_parallel_transformer_layer.py#L250C20-L250C33), so I wonder if it is possible to do in TransformerEngine, because it would be awesome to use it with FP8!
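For reference, here is a minimal PyTorch sketch of the idea (not TransformerEngine's LayerNormMLP internals, and without LayerNorm or FP8): a custom autograd function that saves only the MLP input and the weights, then recomputes FFN1 and GELU in the backward pass so that ffn1_out and gelu_out never need to be stored. The class and variable names (SelectiveCkptMLP, w1/b1/w2/b2) are hypothetical, and the sketch assumes the weights are trainable parameters.

```python
import torch
import torch.nn.functional as F


class SelectiveCkptMLP(torch.autograd.Function):
    """Sketch: save only x and the weights; recompute FFN1 + GELU in backward."""

    @staticmethod
    def forward(ctx, x, w1, b1, w2, b2):
        ffn1_out = F.linear(x, w1, b1)   # [*, ffn_hidden] -- not saved
        gelu_out = F.gelu(ffn1_out)      # [*, ffn_hidden] -- not saved
        out = F.linear(gelu_out, w2, b2)
        # Keep only the (much smaller) input and the parameters.
        ctx.save_for_backward(x, w1, b1, w2, b2)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, w1, b1, w2, b2 = ctx.saved_tensors
        # Selective recompute: re-run FFN1 + GELU only; FFN2 is not re-run.
        with torch.enable_grad():
            x_req = x.detach().requires_grad_(True)
            ffn1_out = F.linear(x_req, w1, b1)
            gelu_out = F.gelu(ffn1_out)

        # FFN2 gradients, using the recomputed gelu_out.
        go2d = grad_out.reshape(-1, grad_out.shape[-1])
        ge2d = gelu_out.reshape(-1, gelu_out.shape[-1])
        grad_w2 = go2d.t() @ ge2d
        grad_b2 = go2d.sum(dim=0)
        grad_gelu = grad_out @ w2        # dL/d(gelu_out)

        # Backprop through the recomputed GELU + FFN1
        # (assumes w1 and b1 require grad, i.e. training mode).
        grad_x, grad_w1, grad_b1 = torch.autograd.grad(
            outputs=gelu_out,
            inputs=(x_req, w1, b1),
            grad_outputs=grad_gelu,
        )
        return grad_x, grad_w1, grad_b1, grad_w2, grad_b2


# Illustrative usage with made-up shapes:
hidden, ffn_hidden = 1024, 4096
x = torch.randn(8, 128, hidden, requires_grad=True)
w1 = torch.nn.Parameter(torch.empty(ffn_hidden, hidden).normal_(std=0.02))
b1 = torch.nn.Parameter(torch.zeros(ffn_hidden))
w2 = torch.nn.Parameter(torch.empty(hidden, ffn_hidden).normal_(std=0.02))
b2 = torch.nn.Parameter(torch.zeros(hidden))
y = SelectiveCkptMLP.apply(x, w1, b1, w2, b2)
y.sum().backward()
```

The key point is that FFN2's weight gradient is computed from the recomputed gelu_out, so only FFN1 and GELU are re-executed while FFN2's (cheaper-to-save) output path is left alone; naively wrapping FFN1 + GELU in torch.utils.checkpoint would not help, since FFN2 would still save gelu_out for its own backward.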

Thank you!

denizokt · Jan 22 '24 16:01

@sudhakarsingh27 This is basically the same as what we discussed about improving the checkpoint logic to allow for early stopping.

ptrendx · May 16 '24 19:05