Support Gemma2 in torchtitan
Are there any plans to support Gemma2 in torchtitan? I tried to use torchtitan to finetune a Gemma2 model, but got stuck on the following problem: how do you parallelize the tied layers in the Gemma2 model? Maybe somebody knows the solution to this problem 😄
If you apply fully_shard to each transformer block and then to the root module, this should work for tied embedding and final linear. The root module will manage both.
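A minimal sketch of the wrapping order I mean, assuming FSDP2 (fully_shard) and a model with a layers container plus a tied embedding/output pair (module names are illustrative, not torchtitan's Gemma2 code):

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # FSDP2

def apply_fsdp(model: nn.Module, mesh):
    # Shard each transformer block first, so each block gets its own FSDP group.
    for block in model.layers:
        fully_shard(block, mesh=mesh)
    # Shard the root module last. The root's group picks up every parameter not
    # already assigned to a block, including the shared embedding/output weight,
    # so both tied references resolve to the same sharded parameter.
    fully_shard(model, mesh=mesh)
    return model
```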
I want to shard the output embedding layer. I use the same strategy as in Llama, but training gets stuck after the first batch:
```python
ColwiseParallel(
    input_layouts=Shard(1),
    output_layouts=Shard(-1) if loss_parallel else Replicate(),
    use_local_output=not loss_parallel,
)
```
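Roughly how I apply it, following torchtitan's Llama TP plan (only the final-projection entry; the module name output comes from Llama, and I point it at the tied head in Gemma2):

```python
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

def apply_output_tp(model, tp_mesh, loss_parallel: bool):
    # Shard the final projection column-wise over the TP mesh; keep the output
    # sharded on the vocab dimension if loss parallelism is enabled.
    parallelize_module(
        model,
        tp_mesh,
        {
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Shard(-1) if loss_parallel else Replicate(),
                use_local_output=not loss_parallel,
            ),
        },
    )
```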
Do you want to train with 2D parallelism (FSDP + TP)? With TP only?
@awgu Hi, when I tied the embedding and lm_head layers, training worked normally when I launched the job from scratch, but it failed when I tried to resume an interrupted job from a checkpoint and load the optimizer state dict:

```
KeyError: 'state.lm_head.weight.step'
```
Everything is OK when the layers are not tied.
Do you have any clues for solving this error? Thank you.
@yzhangcs sorry I am not as familiar with the checkpointing part. @fegin can you give some guidance here? Should the DCP implementation in torchtitan support parameter sharing?
@tianyu-l Hi, I'm curious whether the problem has been solved by https://github.com/pytorch/pytorch/pull/128076. Since I'm using hf-style models, I'm not sure if this error is caused by the way hf implements tied weights.
Missing optimizer state for the tied weights should already be fixed a while ago, https://github.com/pytorch/pytorch/pull/128685. Can you point out which PyTorch version you use? @yzhangcs
Updated: I checked and the fix is in 2.4, so I'm sure your PyTorch is not older than 2.4. Would you be able to provide a repro?
@fegin
Missing optimizer state for the tied weights should already be fixed a while ago
Thank you for pointing this out. I'm using 2.5.1. It looks like it's a problem with the tied layer definition. I'll check it again and let you know if I find any clues.
@fegin I've identified that the issue stems from flatten_optimizer_state_dict, which is set to True by default.
It appears that PyTorch doesn't properly handle tied layers in _unflatten_optim_state_dict.
https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict.py#L898-L905
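My rough reading of the failure (the key names are taken from the errors in this thread, not from the real data structures):

```python
# With flatten_optimizer_state_dict=True, optimizer state and param_groups are
# keyed per parameter FQN. The optimizer holds only ONE tensor for the tied
# pair, so the saved flat dict contains keys for just one of the two FQNs:
flat_osd = {
    "state.embedding.weight.step": ...,
    "state.embedding.weight.exp_avg": ...,
    "param_groups.embedding.weight.lr": 1e-4,
    # but no "state.decoder.weight.*" / "param_groups.decoder.weight.*" entries
}

# On load, the unflattening path walks every FQN the model reports, including
# the second alias of the tied weight, and does a direct lookup on the flat dict:
print("param_groups.decoder.weight.lr" in flat_osd)  # False -> KeyError in the repro below
```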
I've created a minimal test case to reproduce the error:
```python
import copy
from typing import Callable

import torch
import torch.nn as nn
from torch.distributed._tensor import init_device_mesh
from torch.distributed.checkpoint.state_dict import (StateDictOptions,
                                                      get_state_dict,
                                                      set_model_state_dict,
                                                      set_optimizer_state_dict)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase, with_comms)
from torch.testing._internal.distributed.common_state_dict import \
    VerifyStateDictMixin


class TiedEmbeddingModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.decoder = nn.Linear(embedding_dim, vocab_size)
        self.decoder.weight = self.embedding.weight  # Tying weights

    def forward(self, input):
        input = (input * 10).to(torch.int)
        embedded = self.embedding(input)
        output = self.decoder(embedded)
        return output


class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
    """Tests state_dict and load_state_dict"""

    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    def _test_save_load(
        self,
        init_model_optim: Callable,
        test_frozen: bool = False,
        flatten_optimizer_state_dict: bool = False,
    ) -> None:
        options = StateDictOptions(
            ignore_frozen_params=test_frozen,
            flatten_optimizer_state_dict=flatten_optimizer_state_dict,
        )
        # Initialize original model and distributed model.
        model, optim, copy_optim, dist_model, dist_optim = init_model_optim()

        # Train 10 steps.
        _dist_optim = [dist_optim] if not isinstance(dist_optim, list) else dist_optim
        for _ in range(10):
            optim.zero_grad()
            for d_optim in _dist_optim:
                d_optim.zero_grad()
            batch = torch.rand(8, 100, device="cuda")
            model(batch).sum().backward()
            dist_model(batch).sum().backward()
            optim.step()
            for d_optim in _dist_optim:
                d_optim.step()

        # Get the state_dict, and compare the result
        msd = model.state_dict()
        osd = optim.state_dict()
        dist_msd, dist_osd = get_state_dict(
            dist_model, optimizers=dist_optim, options=options
        )
        self._verify_msd(msd, dist_msd, options)
        self._verify_osd_by_load(model, optim, copy_optim, dist_osd)
        self._verify_osd(model, optim, osd, dist_osd)

        # Initialize a completely new model to simulate checkpoint load.
        _, _, _, dist_model, dist_optim = init_model_optim()

        # Simulate DCP distributed load. We need to first get the state_dict and
        # pass them to DCP to load the saved state_dict from the storage.
        # Then finally we can call set_state_dict().
        if not isinstance(dist_optim, list):
            dist_optim = [dist_optim]
        if test_frozen:
            # We won't be able to load the partial state_dict back.
            return

        # Since we already have the state_dict saved before, no need to call DCP.
        # We can directly load them back. This assert is to ensure that the
        # optimizer state storages are initialized.
        # self.assertEqual(len(curr_dist_osd[STATE]), len(dist_osd[STATE]))
        set_model_state_dict(
            dist_model,
            model_state_dict=dist_msd,
            options=options,
        )
        set_optimizer_state_dict(
            dist_model,
            optimizers=dist_optim,
            optim_state_dict=dist_osd,
            options=options,
        )

    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_shared_weight(self):
        def init_model_optim():
            device_mesh = init_device_mesh("cuda", (self.world_size,))
            orig_model = TiedEmbeddingModel(32000, 1024).to(torch.device("cuda"))
            orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4)
            copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4)
            copy_model = copy.deepcopy(orig_model)
            dist_model = FSDP(copy_model, device_mesh=device_mesh)
            dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-4)
            return orig_model, orig_optim, copy_optim, dist_model, dist_optim

        self._test_save_load(init_model_optim, flatten_optimizer_state_dict=False)

    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_shared_weight_flatten(self):
        def init_model_optim():
            device_mesh = init_device_mesh("cuda", (self.world_size,))
            orig_model = TiedEmbeddingModel(32000, 1024).to(torch.device("cuda"))
            orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4)
            copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4)
            copy_model = copy.deepcopy(orig_model)
            dist_model = FSDP(copy_model, device_mesh=device_mesh)
            dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-4)
            return orig_model, orig_optim, copy_optim, dist_model, dist_optim

        self._test_save_load(init_model_optim, flatten_optimizer_state_dict=True)


if __name__ == "__main__":
    run_tests()
```
The output is:

```
...
KeyError: 'param_groups.decoder.weight.lr'
...
========================================================================== short test summary info ===========================================================================
FAILED test_state_dict.py::TestStateDict::test_shared_weight_flatten - RuntimeError: Process 1 exited with error code 10 and exception:
============================================================= 1 failed, 1 passed, 1 warning in 128.12s (0:02:08) =============================================================
```
As you can see, flatten_optimizer_state_dict=False works correctly, but when flatten_optimizer_state_dict=True, the KeyError occurs.
To work around this issue, I've disabled the option in:
https://github.com/pytorch/torchtitan/blob/90567fc9827ffdf17bdd0349cd5276c662d0769a/torchtitan/optimizer.py#L59
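Concretely, the workaround amounts to building and loading the optimizer state dict with flattening turned off, something like this sketch (not the exact torchtitan code at that line):

```python
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)

# Keep the nested (unflattened) optimizer state dict so the second FQN of a
# tied weight never has to be resolved in a flattened key space.
_options = StateDictOptions(flatten_optimizer_state_dict=False)

def save_optim_state(model, optimizer):
    return get_optimizer_state_dict(model, optimizer, options=_options)

def load_optim_state(model, optimizer, state_dict):
    set_optimizer_state_dict(
        model, optimizer, optim_state_dict=state_dict, options=_options
    )
```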
However, I have two concerns:
- I'm wondering if disabling this option might significantly impact performance, especially w/o PP.
- It would be great if PyTorch could provide full support for tied layers in the official code.
Looking forward to your feedback. Thank you.
cc @tianyu-l @awgu
- I'm wondering if disabling this option might significantly impact performance, especially w/o PP.
No, it won't.
- It would be great if PyTorch could provide full support for tied layers in the official code.
Yes, since this is a bug, I'll take a look at this issue. Thanks!
@fegin Hello, I'm wondering if this bug has been fixed in torch?
@yzhangcs https://github.com/pytorch/pytorch/pull/148825 fixes the issue.
Thank you!
@yzhangcs The PR has landed. You should be able to get it with the next nightly PyTorch build. Please let me know if that completely resolves the issue; I can only verify through the UT.
@yzhangcs I'll close the issue. Let me know if you still encounter issues.