ThunderFX fails with FP8 and Activation Checkpointing
🐛 Bug
When training the models 'vicuna-7b-v1.5-16k', 'longchat-13b-16k', 'Mistral-7B-v0.2', 'falcon-180B', 'Llama-3-70B', and 'CodeLlama-34b-hf' with FSDP and FP8, we get KeyError: 'scaling_fwd'. This might also be an issue with Transformer Engine, so I'm happy to move this issue to TE if needed.
Full traceback:
Traceback (most recent call last):
  File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 974, in <module>
    CLI(benchmark_main)
  File "/usr/local/lib/python3.12/dist-packages/jsonargparse/_cli.py", line 96, in CLI
    return _run_component(components, init)
  File "/usr/local/lib/python3.12/dist-packages/jsonargparse/_cli.py", line 204, in _run_component
    return component(**cfg)
  File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 871, in benchmark_main
    benchmark.train()
  File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 765, in train
    loss.backward()
  File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 624, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 600, in wrapper
    outputs = fn(ctx, *args)
  File "/opt/pytorch/lightning-thunder/thunder/executors/torch_autograd.py", line 115, in backward
    grads = ctx.compiled_backward([saved_tensors_list, ctx.saved_other], args)
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
  File "thunder.backward_fn_13", line 28, in backward_fn
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/executors/transformer_engineex.py", line 205, in forward
    weight_fp8, weight_t_fp8 = self.get_fp8_weight_version_compat(
  File "/opt/pytorch/lightning-thunder/thunder/executors/transformer_engineex.py", line 273, in get_fp8_weight_version_compat
    weight_fp8 = self.get_fp8_workspace(
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 1086, in get_fp8_workspace
    out.quantize(tensor, noop_flag=skip_update_flag)
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/tensor/float8_tensor.py", line 642, in quantize
    fp8_meta = dst._fp8_meta[fp8_meta_key]
KeyError: 'scaling_fwd'
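For context (the root cause is diagnosed further below): a TE module's fp8_meta only gains its 'scaling_fwd' entry once the module's forward has run under fp8_autocast, so hitting TE's FP8 weight path in the backward without a preceding TE forward fails on exactly this lookup. A simplified schematic of the failure mode, with illustrative names rather than TE internals:

# Illustrative schematic only, not TE's actual code.
fp8_meta = {}  # per-module FP8 state, normally filled in by TE's forward

def te_forward():
    # Running the module's forward under fp8_autocast initializes the
    # forward-pass scaling state.
    fp8_meta["scaling_fwd"] = "scales and amax history"

def te_backward():
    # The backward path assumes the forward already populated this key.
    return fp8_meta["scaling_fwd"]

te_backward()  # KeyError: 'scaling_fwd' because te_forward never ran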
To Reproduce
Please use:
1 node with 8 GPUs.
Image "INTERNAL_IMAGE:pjnl_20241107"
Training script:
python /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py \
  --model_name Mistral-7B-v0.2 \
  --distributed_mode fsdp \
  --shard_mode zero2 \
  --compile dynamo_thunder \
  --checkpoint_activations True \
  --low_precision_mode fp8-delayed-te \
  --micro_batch_size 1
Environment
system.device_product_name DGXH100
system.gpu_driver_version 535.129.03
libraries.cuda 12.6.98.001
libraries.pip.lightning 2.4.0.dev20240728
libraries.pip.lightning-thunder 0.2.0.dev0
libraries.pip.lightning-utilities 0.11.8
libraries.pip.litgpt 0.4.11
libraries.pip.nvfuser 0.2.22+gitba4f7d4
libraries.pip.pytorch-lightning 2.4.0
libraries.pip.torch 2.6.0a0+gita9b4989
libraries.pip.torchao 0.6.1
libraries.pip.torchmetrics 1.5.1
libraries.pip.torchvision 0.19.0a0+d23a6e1
This seems to be happening due to the interaction of TransformerEngine and activation checkpointing.
Minimal Repro
import torch
import torch.utils.checkpoint

import thunder
from thunder.dynamo import ThunderCompiler
from thunder.executors.transformer_engineex import transformer_engine_ex

def checkpointed_fn(x):
    y = x.cos()
    return torch.nn.functional.linear(x, y)

def fn(x):
    return torch.utils.checkpoint.checkpoint(checkpointed_fn, x)

backend = ThunderCompiler(executors=[transformer_engine_ex])

x = torch.randn(16, 16, device='cuda', requires_grad=True)
o = torch.compile(fn, backend=backend)(x)

assert len(backend.subgraph_infos) == 1
subgraph_info = backend.subgraph_infos[0]
tfn = subgraph_info.thunder_compiled_fns[0]
print(thunder.last_traces(tfn)[-1])
print(thunder.last_backward_traces(tfn)[-1])

o.sum().backward()  # KeyError: 'scaling_fwd'
This happens because the forward calls torch.nn.functional.linear, but the backward calls te_functional_linear_backward without TE's forward ever having been called.
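For intuition about why recomputed ops land in the backward at all, here is a minimal sketch of activation checkpointing (hypothetical and simplified from torch.utils.checkpoint, not Thunder's actual transform): only the inputs are saved, and the checkpointed function is re-executed during backward, so anything that traces or dispatches ops for the backward pass sees a fresh linear there.

import torch

class NaiveCheckpoint(torch.autograd.Function):
    # Hypothetical, minimal activation checkpointing: save only the input
    # and re-run the wrapped function during backward to rebuild the graph.

    @staticmethod
    def forward(ctx, fn, x):
        ctx.fn = fn
        ctx.save_for_backward(x)
        with torch.no_grad():  # no graph is recorded for the forward pass
            return fn(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        with torch.enable_grad():
            x_req = x.detach().requires_grad_(True)
            # The checkpointed function is re-executed HERE, at backward
            # time; in the repro, this recomputed linear is what the TE
            # executor claims as te_linear_0.
            y = ctx.fn(x_req)
        (grad_x,) = torch.autograd.grad(y, (x_req,), grad_out)
        return None, grad_x  # no gradient for fn itself

# usage: out = NaiveCheckpoint.apply(checkpointed_fn, x)

TE's backward then runs against FP8 state that no real TE forward ever initialized, which is precisely the missing 'scaling_fwd'.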
Forward Graph
def computation(l_x_):
  # l_x_: "cuda:0 f32[16, 16]"
  t4 = torch.cos(l_x_) # t4: "cuda:0 f32[16, 16]"
    # t4 = ltorch.cos(l_x_) # t4: "cuda:0 f32[16, 16]"
      # t4 = prims.cos(l_x_) # t4: "cuda:0 f32[16, 16]"
  getitem = torch.nn.functional.linear(l_x_, t4, None) # getitem: "cuda:0 f32[16, 16]"
    # getitem = ltorch.linear(l_x_, t4, None) # getitem: "cuda:0 f32[16, 16]"
      # getitem = prims.linear(l_x_, t4, None) # getitem: "cuda:0 f32[16, 16]"
  del t4
  return {'output': getitem, 'flat_args': [l_x_], 'flat_output': (getitem,)}, ((l_x_,), ())
Backward Graph
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t0, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  l_x_, = C0
  clear_mutable_collection(C0)
  del C0
  t6 = torch.cos(l_x_) # t6: "cuda:0 f32[16, 16]"
    # t6 = ltorch.cos(l_x_) # t6: "cuda:0 f32[16, 16]"
      # t6 = prims.cos(l_x_) # t6: "cuda:0 f32[16, 16]"
  (_, (t10, t11, t12, t13, t14, _), ctx_te_1) = te_linear_0(l_x_, t6, None)
  del t6
  (t19, t20, _) = te_functional_linear_backward((16, 16), (16, 16), None, ctx_te_1, (t10, t11, t12, t13, t14, None), t0)
  del ctx_te_1, t10, t11, t12, t13, t14, t0
  t21 = torch.sin(l_x_) # t21: "cuda:0 f32[16, 16]"
    # t21 = ltorch.sin(l_x_) # t21: "cuda:0 f32[16, 16]"
      # t21 = prims.sin(l_x_) # t21: "cuda:0 f32[16, 16]"
  del l_x_
  t22 = torch.neg(t21) # t22: "cuda:0 f32[16, 16]"
    # t22 = ltorch.neg(t21) # t22: "cuda:0 f32[16, 16]"
      # t22 = prims.neg(t21) # t22: "cuda:0 f32[16, 16]"
  del t21
  t23 = torch.mul(t20, t22) # t23: "cuda:0 f32[16, 16]"
    # t23 = ltorch.mul(t20, t22) # t23: "cuda:0 f32[16, 16]"
      # t23 = prims.mul(t20, t22) # t23: "cuda:0 f32[16, 16]"
  del t20, t22
  t24 = torch.add(t19, t23) # t24: "cuda:0 f32[16, 16]"
    # t24 = ltorch.add(t19, t23, alpha=1) # t24: "cuda:0 f32[16, 16]"
      # t24 = prims.add(t19, t23) # t24: "cuda:0 f32[16, 16]"
  del t19, t23
  te_sync_fp8_meta_bwd()
  return (t24,)
@kiya00 do you know why this could be happening? Thanks!
> This happens because the forward calls torch.nn.functional.linear, but the backward calls te_functional_linear_backward without TE's forward ever having been called.
https://github.com/Lightning-AI/lightning-thunder/blob/60f3ee1ec536ee8d6fdef503af54525e0a3978a4/thunder/torch/init.py#L5319-L5331
Checkpointing uses VJP; is the te_linear_0 in the backward trace the original torch.nn.functional.linear?
At https://github.com/Lightning-AI/lightning-thunder/blob/60f3ee1ec536ee8d6fdef503af54525e0a3978a4/thunder/core/transforms.py#L2819, the input trace is:
@torch.no_grad()
@no_autocast
def flat_func(*flat_args):
  # flat_args: "Collection"
  t0, = flat_args
  t1 = ltorch.cos(t0) # t1: "cuda:0 f32[16, 16]"
    # t1 = prims.cos(t0) # t1: "cuda:0 f32[16, 16]"
  t2 = ltorch.linear(t0, t1, None) # t2: "cuda:0 f32[16, 16]"
    # t2 = prims.linear(t0, t1, None) # t2: "cuda:0 f32[16, 16]"
  return (t2,)
and after L2819, it seems the linear becomes te_linear_0 in the backward_fn:
@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, = saved_for_backward
  t0, = cotangents
  _torch_fx_graph_module_GraphModule___new___<locals>_GraphModuleImpl_0, C1, C2, \
  = C0
  l_x_, = C1
  # C2 (empty dict)
  t6 = prims.cos(l_x_) # t6: "cuda:0 f32[16, 16]"
  t7 = ltorch.view(t0, (-1, 16)) # t7: "cuda:0 f32[16, 16]"
    # t7 = ltorch.reshape(t0, (-1, 16)) # t7: "cuda:0 f32[16, 16]"
      # t7 = prims.reshape(t0, (16, 16)) # t7: "cuda:0 f32[16, 16]"
  _ = ltorch.dim(t7)
  (_, _) = prims.shape(t7)
  (_, _) = prims.shape(t7)
  _ = ltorch.dim(t1)
  (_, _) = prims.shape(t1)
  (_, _) = prims.shape(t1)
  t8 = ltorch.view(a, (-1, 16)) # t8: "cuda:0 f32[16, 16]"
    # t8 = ltorch.reshape(a, (-1, 16)) # t8: "cuda:0 f32[16, 16]"
      # t8 = prims.reshape(a, (16, 16)) # t8: "cuda:0 f32[16, 16]"
  _ = ltorch.dim(t8)
  (_, _) = prims.shape(t8)
  (_, _) = prims.shape(t8)
  _ = ltorch.dim(w)
  (_, _) = prims.shape(w)
  (_, _) = prims.shape(w)
  t9 = ltorch.view(a, (-1, 16)) # t9: "cuda:0 f32[16, 16]"
    # t9 = ltorch.reshape(a, (-1, 16)) # t9: "cuda:0 f32[16, 16]"
      # t9 = prims.reshape(a, (16, 16)) # t9: "cuda:0 f32[16, 16]"
  _ = ltorch.dim(t9)
  (_, _) = prims.shape(t9)
  (_, _) = prims.shape(t9)
  _ = ltorch.dim(w)
  (_, _) = prims.shape(w)
  (_, _) = prims.shape(w)
  (t15, (t10, t11, t12, t13, t14, _), ctx_te_1) = te_linear_0(l_x_, t6, None)
TransformerEngine has its own checkpoint function. How is it different from PyTorch's checkpoint? Can TransformerEngine's checkpoint be used with the existing executor in Thunder to do activation checkpointing?
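For reference, a sketch of what the TE-native wrapper looks like in use, assuming the transformer_engine.pytorch.checkpoint API from TE's documentation (per those docs, the key difference from torch.utils.checkpoint.checkpoint is that it also saves and restores the FP8 state, i.e. scaling factors and amax history, around the recomputation). Whether it composes with Thunder's existing executor is exactly the open question above:

import torch
import transformer_engine.pytorch as te

layer = te.Linear(16, 16)  # a TE module, so FP8 metadata lives on it

def block(x):
    return layer(x)

x = torch.randn(16, 16, device="cuda", requires_grad=True)
with te.fp8_autocast(enabled=True):
    # Drop-in analogue of torch.utils.checkpoint.checkpoint that also
    # restores fp8_meta (scaling factors, amax history) before recompute.
    y = te.checkpoint(block, x)
y.sum().backward()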
Follow-up in the tracking issue for the functional TE executor: #2027.
@riccardofelluga, is this really completed in #1908?
> @riccardofelluga, is this really completed in #1908?
No, let's keep it around until the new executor becomes the default.
@riccardofelluga, when does the new executor become the default? Which PR should I watch to track that?
Verified solved after merging #2344.
Closing #2170 will close this issue.
This now works on main thanks to #2510.