
ThunderFX fails with FP8 and Activation Checkpointing

Open: mpatel31415 opened this issue 1 year ago (5 comments)

🐛 Bug

When training the models 'vicuna-7b-v1.5-16k', 'longchat-13b-16k', 'Mistral-7B-v0.2', 'falcon-180B', 'Llama-3-70B', and 'CodeLlama-34b-hf' with FSDP, FP8, and activation checkpointing, we get KeyError: 'scaling_fwd'. This might also be an issue with Transformer Engine, so I'm happy to move this issue to TE if needed.

Full traceback (rank 7):

Traceback (most recent call last):
  File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 974, in <module>
    CLI(benchmark_main)
  File "/usr/local/lib/python3.12/dist-packages/jsonargparse/_cli.py", line 96, in CLI
    return _run_component(components, init)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/jsonargparse/_cli.py", line 204, in _run_component
    return component(**cfg)
           ^^^^^^^^^^^^^^^^
  File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 871, in benchmark_main
    benchmark.train()
  File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 765, in train
    loss.backward()
  File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 624, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 600, in wrapper
    outputs = fn(ctx, *args)
              ^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/executors/torch_autograd.py", line 115, in backward
    grads = ctx.compiled_backward([saved_tensors_list, ctx.saved_other], args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "thunder.backward_fn_13", line 28, in backward_fn
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/executors/transformer_engineex.py", line 205, in forward
    weight_fp8, weight_t_fp8 = self.get_fp8_weight_version_compat(
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/executors/transformer_engineex.py", line 273, in get_fp8_weight_version_compat
    weight_fp8 = self.get_fp8_workspace(
                 ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 1086, in get_fp8_workspace
    out.quantize(tensor, noop_flag=skip_update_flag)
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/tensor/float8_tensor.py", line 642, in quantize
    fp8_meta = dst._fp8_meta[fp8_meta_key]
               ~~~~~~~~~~~~~^^^^^^^^^^^^^^
KeyError: 'scaling_fwd'

To Reproduce

Please use 1 node with 8 GPUs and the image "INTERNAL_IMAGE:pjnl_20241107". Training script:

python /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py \
    --model_name Mistral-7B-v0.2 \
    --distributed_mode fsdp \
    --shard_mode zero2 \
    --compile dynamo_thunder \
    --checkpoint_activations True \
    --low_precision_mode fp8-delayed-te \
    --micro_batch_size 1

Environment

system.device_product_name: DGXH100
system.gpu_driver_version: 535.129.03
libraries.cuda: 12.6.98.001
libraries.pip.lightning: 2.4.0.dev20240728
libraries.pip.lightning-thunder: 0.2.0.dev0
libraries.pip.lightning-utilities: 0.11.8
libraries.pip.litgpt: 0.4.11
libraries.pip.nvfuser: 0.2.22+gitba4f7d4
libraries.pip.pytorch-lightning: 2.4.0
libraries.pip.torch: 2.6.0a0+gita9b4989
libraries.pip.torchao: 0.6.1
libraries.pip.torchmetrics: 1.5.1
libraries.pip.torchvision: 0.19.0a0+d23a6e1

mpatel31415, Nov 12 '24 09:11

This seems to be caused by an interaction between TransformerEngine and activation checkpointing.

Minimal Repro

import torch
import torch.utils.checkpoint

import thunder
from thunder.dynamo import ThunderCompiler
from thunder.executors.transformer_engineex import transformer_engine_ex

def checkpointed_fn(x):
    # Runs under activation checkpointing, so it is recomputed during backward.
    y = x.cos()
    return torch.nn.functional.linear(x, y)

def fn(x):
    return torch.utils.checkpoint.checkpoint(checkpointed_fn, x)

# Compile with ThunderFX and let the TransformerEngine executor claim linear ops.
backend = ThunderCompiler(executors=[transformer_engine_ex,])
x = torch.randn(16, 16, device='cuda', requires_grad=True)
o = torch.compile(fn, backend=backend)(x)

# Print the final forward and backward traces of the single Thunder subgraph.
assert len(backend.subgraph_infos) == 1
subgraph_info = backend.subgraph_infos[0]
tfn = subgraph_info.thunder_compiled_fns[0]
print(thunder.last_traces(tfn)[-1])
print(thunder.last_backward_traces(tfn)[-1])

o.sum().backward()  # KeyError: 'scaling_fwd'

This happens because in the forward we are calling torch.nn.functional.linear but in the backward, we are calling te_functional_linear_backward (without ever calling the TE's forward).
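The traceback ends in TE's quantize step looking up fp8_meta_key in a metadata dict. A toy model of that failure mode follows, assuming (not verified against TE's source) that the 'scaling_fwd' entry is created by TE's FP8 setup during a normal FP8-enabled forward, and that the checkpointed backward reaches the quantize step without that setup having run. ToyFP8Module, init_fp8_meta, quantize_weight, and run are hypothetical names, not TE APIs.

```python
# Toy model only: these classes and methods are hypothetical, not TransformerEngine APIs.


class ToyFP8Module:
    def __init__(self):
        # Like a per-module FP8 metadata dict, this starts out empty.
        self.fp8_meta = {}

    def init_fp8_meta(self):
        # Stand-in for the setup an FP8-enabled forward performs (assumption).
        self.fp8_meta["scaling_fwd"] = {"scale": 1.0}

    def quantize_weight(self, key="scaling_fwd"):
        # Mirrors the failing line in the traceback: dst._fp8_meta[fp8_meta_key]
        return self.fp8_meta[key]["scale"]


def run(initialized: bool):
    module = ToyFP8Module()
    if initialized:
        module.init_fp8_meta()  # normal path: the forward set up the scaling state
    try:
        return module.quantize_weight()
    except KeyError as err:
        # Checkpointed-backward path: the scaling state was never created.
        return f"KeyError: {err}"


print(run(True))   # 1.0
print(run(False))  # KeyError: 'scaling_fwd'
```

If this model is right, a fix has to either run the FP8 setup before the recomputed TE linear or keep that linear out of TE during recomputation.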

Forward Graph

def computation(l_x_):
  # l_x_: "cuda:0 f32[16, 16]"
  t4 = torch.cos(l_x_)  # t4: "cuda:0 f32[16, 16]"
    # t4 = ltorch.cos(l_x_)  # t4: "cuda:0 f32[16, 16]"
      # t4 = prims.cos(l_x_)  # t4: "cuda:0 f32[16, 16]"
  getitem = torch.nn.functional.linear(l_x_, t4, None)  # getitem: "cuda:0 f32[16, 16]"
    # getitem = ltorch.linear(l_x_, t4, None)  # getitem: "cuda:0 f32[16, 16]"
      # getitem = prims.linear(l_x_, t4, None)  # getitem: "cuda:0 f32[16, 16]"
  del t4
  return {'output': getitem, 'flat_args': [l_x_], 'flat_output': (getitem,)}, ((l_x_,), ())

Backward Graph

def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t0, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  l_x_, = C0
  clear_mutable_collection(C0)
  del C0
  t6 = torch.cos(l_x_)  # t6: "cuda:0 f32[16, 16]"
    # t6 = ltorch.cos(l_x_)  # t6: "cuda:0 f32[16, 16]"
      # t6 = prims.cos(l_x_)  # t6: "cuda:0 f32[16, 16]"
  (_, (t10, t11, t12, t13, t14, _), ctx_te_1) = te_linear_0(l_x_, t6, None)
  del t6
  (t19, t20, _) = te_functional_linear_backward((16, 16), (16, 16), None, ctx_te_1, (t10, t11, t12, t13, t14, None), t0)
  del ctx_te_1, t10, t11, t12, t13, t14, t0
  t21 = torch.sin(l_x_)  # t21: "cuda:0 f32[16, 16]"
    # t21 = ltorch.sin(l_x_)  # t21: "cuda:0 f32[16, 16]"
      # t21 = prims.sin(l_x_)  # t21: "cuda:0 f32[16, 16]"
  del l_x_
  t22 = torch.neg(t21)  # t22: "cuda:0 f32[16, 16]"
    # t22 = ltorch.neg(t21)  # t22: "cuda:0 f32[16, 16]"
      # t22 = prims.neg(t21)  # t22: "cuda:0 f32[16, 16]"
  del t21
  t23 = torch.mul(t20, t22)  # t23: "cuda:0 f32[16, 16]"
    # t23 = ltorch.mul(t20, t22)  # t23: "cuda:0 f32[16, 16]"
      # t23 = prims.mul(t20, t22)  # t23: "cuda:0 f32[16, 16]"
  del t20, t22
  t24 = torch.add(t19, t23)  # t24: "cuda:0 f32[16, 16]"
    # t24 = ltorch.add(t19, t23, alpha=1)  # t24: "cuda:0 f32[16, 16]"
      # t24 = prims.add(t19, t23)  # t24: "cuda:0 f32[16, 16]"
  del t19, t23
  te_sync_fp8_meta_bwd()
  return (t24,)

@kiya00 do you know why this could be happening? Thanks!

kshitij12345, Nov 19 '24 13:11

> This happens because in the forward we are calling torch.nn.functional.linear but in the backward, we are calling te_functional_linear_backward (without ever calling the TE's forward).

Checkpointing uses vjp (https://github.com/Lightning-AI/lightning-thunder/blob/60f3ee1ec536ee8d6fdef503af54525e0a3978a4/thunder/torch/__init__.py#L5319-L5331). Is the te_linear_0 in the backward trace the original torch.nn.functional.linear?
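For context on the recompute-based approach, here is a plain-PyTorch analogue of checkpointing via a vector-Jacobian product. It is not Thunder's implementation at the linked lines; RecomputeCheckpoint and count_calls_and_check are hypothetical names. It saves only the input in the forward, reruns the function inside backward, and calls torch.autograd.grad with the incoming cotangent. That recomputation is where the linear shows up in the backward trace of the repro (as te_linear_0).

```python
import torch


class RecomputeCheckpoint(torch.autograd.Function):
    """Save only the input; recompute fn during backward and take a vjp."""

    @staticmethod
    def forward(ctx, fn, x):
        ctx.fn = fn
        ctx.save_for_backward(x)
        # Autograd does not record inside Function.forward, so no activations are kept.
        return fn(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        with torch.enable_grad():
            x_ = x.detach().requires_grad_(True)
            out = ctx.fn(x_)  # recomputation happens inside backward
            (grad_x,) = torch.autograd.grad(out, x_, grad_out)  # vjp with the cotangent
        return None, grad_x  # no gradient for the fn argument


def count_calls_and_check():
    calls = []

    def checkpointed_fn(x):
        calls.append(1)
        y = x.cos()
        return torch.nn.functional.linear(x, y)

    x = torch.randn(16, 16, requires_grad=True)
    RecomputeCheckpoint.apply(checkpointed_fn, x).sum().backward()
    n_calls = len(calls)  # one call in forward, one recompute in backward

    # Reference gradient without checkpointing.
    x_ref = x.detach().clone().requires_grad_(True)
    checkpointed_fn(x_ref).sum().backward()
    return n_calls, torch.allclose(x.grad, x_ref.grad)


print(count_calls_and_check())  # (2, True)
```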

kiya00, Nov 19 '24 13:11

At https://github.com/Lightning-AI/lightning-thunder/blob/60f3ee1ec536ee8d6fdef503af54525e0a3978a4/thunder/core/transforms.py#L2819, the input trace is:

@torch.no_grad()
@no_autocast
def flat_func(*flat_args):
  # flat_args: "Collection"
  t0, = flat_args
  t1 = ltorch.cos(t0)  # t1: "cuda:0 f32[16, 16]"
    # t1 = prims.cos(t0)  # t1: "cuda:0 f32[16, 16]"
  t2 = ltorch.linear(t0, t1, None)  # t2: "cuda:0 f32[16, 16]"
    # t2 = prims.linear(t0, t1, None)  # t2: "cuda:0 f32[16, 16]"
  return (t2,)

After L2819, the linear appears to have become te_linear_0 in backward_fn:

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, = saved_for_backward
  t0, = cotangents
  _torch_fx_graph_module_GraphModule___new___<locals>_GraphModuleImpl_0, C1, C2, \
  = C0
  l_x_, = C1
  # C2 (empty dict)
  t6 = prims.cos(l_x_)  # t6: "cuda:0 f32[16, 16]"
  t7 = ltorch.view(t0, (-1, 16))  # t7: "cuda:0 f32[16, 16]"
    # t7 = ltorch.reshape(t0, (-1, 16))  # t7: "cuda:0 f32[16, 16]"
      # t7 = prims.reshape(t0, (16, 16))  # t7: "cuda:0 f32[16, 16]"
  _ = ltorch.dim(t7)
  (_, _) = prims.shape(t7)
  (_, _) = prims.shape(t7)
  _ = ltorch.dim(t1)
  (_, _) = prims.shape(t1)
  (_, _) = prims.shape(t1)
  t8 = ltorch.view(a, (-1, 16))  # t8: "cuda:0 f32[16, 16]"
    # t8 = ltorch.reshape(a, (-1, 16))  # t8: "cuda:0 f32[16, 16]"
      # t8 = prims.reshape(a, (16, 16))  # t8: "cuda:0 f32[16, 16]"
  _ = ltorch.dim(t8)
  (_, _) = prims.shape(t8)
  (_, _) = prims.shape(t8)
  _ = ltorch.dim(w)
  (_, _) = prims.shape(w)
  (_, _) = prims.shape(w)
  t9 = ltorch.view(a, (-1, 16))  # t9: "cuda:0 f32[16, 16]"
    # t9 = ltorch.reshape(a, (-1, 16))  # t9: "cuda:0 f32[16, 16]"
      # t9 = prims.reshape(a, (16, 16))  # t9: "cuda:0 f32[16, 16]"
  _ = ltorch.dim(t9)
  (_, _) = prims.shape(t9)
  (_, _) = prims.shape(t9)
  _ = ltorch.dim(w)
  (_, _) = prims.shape(w)
  (_, _) = prims.shape(w)
  (t15, (t10, t11, t12, t13, t14, _), ctx_te_1) = te_linear_0(l_x_, t6, None)

kiya00, Nov 19 '24 13:11

TransformerEngine has its own checkpoint function. How is it different from PyTorch's checkpoint? Can TransformerEngine's checkpoint be used with the existing executor in Thunder to do activation checkpointing?

IvanYashchuk, Mar 11 '25 10:03

Follow-up in the tracking issue for the functional TE executor: #2027.

riccardofelluga, May 05 '25 06:05

@riccardofelluga, is this really completed in #1908?

IvanYashchuk, May 22 '25 15:05

> @riccardofelluga, is this really completed in #1908?

No, let's keep it open until the new executor becomes the default.

riccardofelluga, May 22 '25 15:05

@riccardofelluga, when does the new executor become the default? Which PR should I follow to track that?

nvMelissa, May 28 '25 19:05

Verified as solved after merging #2344.

Closing #2170 will close this issue.

riccardofelluga, Aug 08 '25 09:08

This now works on main thanks to #2510

riccardofelluga, Sep 11 '25 12:09