Models trained with FSDP + Thunder don't work with litgpt chat
🐛 Bug
I was able to train a Llama-3-8B model with Thunder for a few steps and then save it. However, when I later try to use `litgpt generate` or `litgpt chat` with the saved checkpoint, I get a size-mismatch error. When I run the training in eager mode, everything works.
To Reproduce
- Extract this archive, Meta-Llama-3-8B-tuned.zip, and put all the files into a directory of your choice (let's call it CHECKPOINT_DIR). Here is the license.
These are Llama-3-8B configuration files (no weights); they can also be downloaded by running:
```
litgpt download meta-llama/Meta-Llama-3-8B
```
- Copy the benchmarking script from this repo, located at `thunder/benchmarks/benchmark_litgpt.py`, and add model saving at line 622:
```python
torch_dist.barrier()
states = benchmark.model.state_dict()
if global_rank == 0:
    torch.save(states, "/lightning-thunder/checkpoints/meta-llama/Meta-Llama-3-8B-tuned/lit_model.pth")
```
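For reference, with PyTorch's native FSDP one would typically gather a full (unsharded) state dict before saving. Below is a minimal sketch of that idiom, assuming `benchmark.model` is the FSDP-wrapped module; this uses PyTorch's FSDP API and may not carry over directly to Thunder's FSDP:
```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

# Gather the full (unsharded) parameters, materializing them on rank 0 only.
full_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(benchmark.model, StateDictType.FULL_STATE_DICT, full_cfg):
    states = benchmark.model.state_dict()
if global_rank == 0:
    torch.save(states, "/lightning-thunder/checkpoints/meta-llama/Meta-Llama-3-8B-tuned/lit_model.pth")
```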
To be sure the script version is the same, I'm also attaching the full, modified file (it's Python code, but I can only attach .txt files here): benchmark_litgpt.txt
Let's assume it's located in a directory called SCRIPT_DIR.
- Start a docker container on a node with 8xH100 (INTERNAL_IMAGE:nvidia is an NVIDIA internal container from 20240731):
```
docker run --pull=always --gpus all --ipc=host \
  --ulimit memlock=-1 --ulimit stack=67108864 -it \
  -v ${CHECKPOINT_DIR}:/lightning-thunder/checkpoints/meta-llama/Meta-Llama-3-8B-tuned \
  -v ${SCRIPT_DIR}:/repro \
  INTERNAL_IMAGE:nvidia
```
- Install a recent litgpt version:
```
python -m pip install litgpt==0.4.5
```
For Eager
5E. Run training in eager mode (on dummy data, so the output won't make sense, but it keeps the reproduction instructions simple):
```
torchrun --standalone --max-restarts=0 --nproc-per-node=8 /repro/benchmark_litgpt.py --model_name Llama-3-8B --max_iters 10 --warmup_iters 2 --distributed_mode fsdp --shard_mode zero3 --bucketing_mode block
```
You should see a new file, lit_model.pth, in the checkpoint directory.
6E. Try to chat with the saved model:
```
litgpt chat /lightning-thunder/checkpoints/meta-llama/Meta-Llama-3-8B-tuned
```
It should run but return garbage (expected, since it was trained on dummy data).
For Thunder
5T. You can remove lit_model.pth first (it will be overwritten anyway), then run:
```
torchrun --standalone --max-restarts=0 --nproc-per-node=8 /repro/benchmark_litgpt.py --model_name Llama-3-8B --max_iters 10 --warmup_iters 2 --distributed_mode fsdp --shard_mode zero3 --bucketing_mode block --compile thunder
```
6T. Try to chat with the saved model:
```
litgpt chat /lightning-thunder/checkpoints/meta-llama/Meta-Llama-3-8B-tuned
```
There is an error:
```
{'access_token': None, 'checkpoint_dir': PosixPath('/lightning-thunder/checkpoints/meta-llama/Meta-Llama-3-8B-tuned'), 'compile': False, 'max_new_tokens': 50, 'multiline': False, 'precision': None, 'quantize': None, 'temperature': 0.8, 'top_k': 200, 'top_p': 1.0}
Traceback (most recent call last):
  File "/usr/local/bin/litgpt", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/litgpt/__main__.py", line 71, in main
    CLI(parser_data)
  File "/usr/local/lib/python3.10/dist-packages/jsonargparse/_cli.py", line 119, in CLI
    return _run_component(component, init.get(subcommand))
  File "/usr/local/lib/python3.10/dist-packages/jsonargparse/_cli.py", line 204, in _run_component
    return component(**cfg)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/litgpt/chat/base.py", line 258, in main
    load_checkpoint(fabric, model, checkpoint_path)
  File "/usr/local/lib/python3.10/dist-packages/litgpt/utils.py", line 362, in load_checkpoint
    model.load_state_dict(state_dict, strict=strict)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2542, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for GPT:
    size mismatch for lm_head.weight: copying a param with shape torch.Size([16032, 4096]) from checkpoint, the shape in current model is torch.Size([128256, 4096]).
    size mismatch for transformer.wte.weight: copying a param with shape torch.Size([16032, 4096]) from checkpoint, the shape in current model is torch.Size([128256, 4096]).
    size mismatch for transformer.h.0.norm_1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([4096]).
    ...
```
Note that every mismatched dimension in the checkpoint is exactly 1/8 of the expected size (16032 × 8 = 128256, 512 × 8 = 4096), which suggests the saved state_dict contains the per-rank shards from the 8-GPU FSDP run instead of the full parameters.
Expected behavior
We should be able to run a model trained with Thunder using the litgpt instructions.
Environment
nvidia-smi output:
Package versions:
```
lightning-thunder    0.2.0.dev0          /opt/pytorch/lightning-thunder
lightning-utilities  0.11.6
litgpt               0.4.5
nvfuser              0.2.8+gitaf62096    /opt/pytorch/nvfuser
pytorch-lightning    2.3.3
torch                2.5.0a0+git83db609
torchmetrics         1.4.0.post0
torchvision          0.19.0a0+d23a6e1
```
https://github.com/Lightning-AI/lightning-thunder/issues/564 could be related.
FYI: I wanted to be sure the checkpoint-saving code is correct in eager mode, so I ran it on each rank and compared (a) the shapes of the parameters in the state_dict against the original model (lit_model.pth, before it was wrapped with FSDP) and (b) the values across ranks (to check that they were synchronized). Both the shapes and the values match. A sketch of the cross-rank check is shown below.
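This is roughly what the cross-rank check looked like; a minimal sketch, assuming the names `benchmark`, `torch_dist`, and the process group from the modified benchmark script above (illustrative, not the exact code I ran):
```python
import torch

# Compare each rank's saved parameters against rank 0's copy to verify
# that the state_dict is synchronized across ranks.
states = benchmark.model.state_dict()
for name, param in states.items():
    reference = param.detach().clone()
    torch_dist.broadcast(reference, src=0)  # overwrite the clone with rank 0's values
    assert param.shape == reference.shape, f"shape mismatch for {name}"
    assert torch.equal(param, reference), f"value mismatch for {name}"
```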
Small update after discussion with @carmocca about saving checkpoints from Thunder FSDP:
I tried using the save and get_model_state_dict functions provided by Thunder and then converting the checkpoint into a torch.save checkpoint using dcp_to_torch_save, but I also get a shape error when later trying to use the output with litgpt chat.
Below is the code I used (it should be possible to copy it in place of the code from the original description):
```python
from thunder.distributed.checkpoint import save, get_model_state_dict, StateDictOptions
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

options = StateDictOptions(full_state_dict=False, cpu_offload=False)
state_dict = get_model_state_dict(model, options, rank)
dcp_path = "/lightning-thunder/checkpoints/meta-llama/Meta-Llama-3-8B-tuned/distributed_ckp"
save(state_dict, dcp_path)
torch_dist.barrier()
if rank == 0:
    dcp_to_torch_save(dcp_path, "/lightning-thunder/checkpoints/meta-llama/Meta-Llama-3-8B-tuned/lit_model.pth")
```
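For debugging, the shapes in the converted checkpoint can be inspected directly; a small sketch (not part of the repro, path as above):
```python
import torch

# Print every parameter's shape in the converted checkpoint so it can be
# compared against the expected Llama-3-8B shapes (e.g. lm_head.weight
# should be [128256, 4096], not a per-rank shard like [16032, 4096]).
sd = torch.load(
    "/lightning-thunder/checkpoints/meta-llama/Meta-Llama-3-8B-tuned/lit_model.pth",
    map_location="cpu",
)
for name, tensor in sd.items():
    print(name, tuple(tensor.shape))
```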
The only option that could make it work now is to train the model with Fabric FSDP, but I haven't tested it yet.
triage review:
- we actually talked about this on Slack
- PyTorch might have a not-so-great example here that led us astray; this is related to how checkpoints are saved
- needs some more thought; unclear what's at fault and why
Hi! Is there any update on this? From the Slack discussion, as I understand it, there were three options for me to make progress:
- Save a distributed checkpoint in Thunder, convert it to a "torch save" checkpoint using the PyTorch function (this should be possible because the distributed Thunder checkpoint is expected to be equivalent to a distributed PyTorch checkpoint), and use the resulting "torch save" checkpoint with litgpt chat. This option has also failed, as shown above. Should I wait for that to be resolved?
- Train with Fabric, and everything should work. But do we expect Thunder to work with LitGPT only when Fabric is used? If so, or if this is the best solution for now, I can change the demo code to use Fabric.
- Write my own "chat" script that loads the distributed Thunder checkpoint.
Please let me know which direction is best to follow from your perspective.