fix _te_version issue in transformer_engine.py get_cpu_offload_context()
Issue #988: when the transformer-engine version is higher than 1.8.0 and lower than 1.10.0, `get_cpu_offload_context()` in transformer_engine.py fails: https://github.com/NVIDIA/Megatron-LM/blob/094d66b488514beaac2106c3e0f9581d27ea9533/megatron/core/transformer/custom_layers/transformer_engine.py#L890-L904
```
[rank1]: Traceback (most recent call last):
[rank1]:   File "/workspace/megatron/examples/run_simple_mcore_train_loop.py", line 121, in <module>
[rank1]:     gpt_model = model_provider()
[rank1]:   File "/workspace/megatron/examples/run_simple_mcore_train_loop.py", line 47, in model_provider
[rank1]:     gpt_model = GPTModel(
[rank1]:   File "/workspace/megatron/megatron/core/models/gpt/gpt_model.py", line 101, in __init__
[rank1]:     self.decoder = TransformerBlock(
[rank1]:   File "/workspace/megatron/megatron/core/transformer/transformer_block.py", line 148, in __init__
[rank1]:     get_cpu_offload_context(
[rank1]:   File "/workspace/megatron/megatron/core/transformer/custom_layers/transformer_engine.py", line 898, in get_cpu_offload_context
[rank1]:     context, sync_func = _get_cpu_offload_context(
[rank1]: TypeError: get_cpu_offload_context() takes from 0 to 4 positional arguments but 5 were given
```
I have checked the source code of transformer-engine. Only when the transformer-engine version is higher than 1.9.0 does `get_cpu_offload_context()` take 5 arguments, so the `1.8.0` in the code should be changed to `1.9.0`:
```python
from transformer_engine.pytorch.cpu_offload import (
    get_cpu_offload_context as _get_cpu_offload_context,
)

def get_cpu_offload_context(
    enabled, num_layers, model_layers, activation_offloading, weight_offloading
):
    if _te_version > packaging.version.Version("1.9.0"):
        context, sync_func = _get_cpu_offload_context(
            enabled, num_layers, model_layers, activation_offloading, weight_offloading
        )
    else:
        context, sync_func = _get_cpu_offload_context(
            enabled, num_layers, activation_offloading, weight_offloading
        )

    return context, sync_func
```
Supplement: `get_cpu_offload_context()` in TransformerEngine branch release_v1.9: https://github.com/NVIDIA/TransformerEngine/blob/ba36f90d05c203787294b7e490af901d79f07d30/transformer_engine/pytorch/cpu_offload.py#L482 and in branch main (version 1.10.0.dev0): https://github.com/NVIDIA/TransformerEngine/blob/def4d1cbfd24e4bb28608d045634a817f638abb7/transformer_engine/pytorch/cpu_offload.py#L438
@akoumpa
Thanks @1195343015 for the contribution; we will include your PR soon. Thanks again.
Fixed in https://github.com/NVIDIA/Megatron-LM/commit/98b43c91d004dec254f1610d9cffae8aff8550f3.