
Support Gemma2 in torchtitan

pansershrek opened this issue

Are there any plans to support Gemma2 in torchtitan? I tried to use torchtitan to finetune a Gemma2 model, but got stuck on the following problem: how do you parallelize the tied layers in the Gemma2 model? Maybe somebody knows the solution to this problem 😄

pansershrek avatar Oct 01 '24 11:10 pansershrek

If you apply fully_shard to each transformer block and then to the root module, this should work for tied embedding and final linear. The root module will manage both.
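
For illustration, a minimal sketch of that ordering (assuming a Llama-style model whose blocks live in an nn.ModuleList called model.layers; depending on the PyTorch version, fully_shard is imported from torch.distributed._composable.fsdp or torch.distributed.fsdp):

from torch.distributed._composable.fsdp import fully_shard

# Shard each transformer block as its own FSDP parameter group.
for block in model.layers:
    fully_shard(block)

# The root call sweeps up all remaining parameters, including the tied
# embedding/output weight. Because the tied weight is a single nn.Parameter,
# it is registered (and sharded) exactly once, managed by the root module.
fully_shard(model)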

awgu avatar Oct 01 '24 14:10 awgu

I want to shard the output embedding layer. I use the same strategy as in Llama, but training gets stuck after the first batch:

ColwiseParallel(
    input_layouts=Shard(1),
    output_layouts=Shard(-1) if loss_parallel else Replicate(),
    use_local_output=not loss_parallel,
)
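
For context, in the Llama code this style of plan is applied with parallelize_module; a rough sketch (assuming the final projection is model.output and tp_mesh is a 1-D tensor-parallel DeviceMesh; the Shard/Replicate import path varies across PyTorch versions):

from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

parallelize_module(
    model.output,  # final nn.Linear projecting hidden states to the vocabulary
    tp_mesh,       # 1-D tensor-parallel DeviceMesh
    ColwiseParallel(
        input_layouts=Shard(1),  # input arrives sharded on the sequence dim
        output_layouts=Shard(-1) if loss_parallel else Replicate(),
        use_local_output=not loss_parallel,
    ),
)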

pansershrek avatar Oct 01 '24 14:10 pansershrek

Do you want to train with 2D parallelism (FSDP + TP)? With TP only?

awgu avatar Oct 01 '24 15:10 awgu

@awgu Hi, when I tied the embedding and lm_head weights, everything worked normally when launching the job from scratch, but it failed when trying to resume the interrupted job from a checkpoint and load the optimizer state dict: KeyError: 'state.lm_head.weight.step'.

Everything is OK when the tied layers are removed.

Do you have any clues for solving this error? Thank you.

yzhangcs avatar Jan 05 '25 20:01 yzhangcs

@yzhangcs sorry I am not as familiar with the checkpointing part. @fegin can you give some guidance here? Should the DCP implementation in torchtitan support parameter sharing?

awgu avatar Jan 05 '25 20:01 awgu

@tianyu-l Hi, curious if the problem has been solved by https://github.com/pytorch/pytorch/pull/128076. Since I'm using HF-style models, I'm not sure if this error is caused by the way HF implements tied weights.

yzhangcs avatar Jan 07 '25 21:01 yzhangcs

Missing optimizer state for tied weights should have been fixed a while ago by https://github.com/pytorch/pytorch/pull/128685. Can you point out which PyTorch version you are using? @yzhangcs

Update: I checked that the fix is in 2.4, so I assume your PyTorch is not older than 2.4. Would you be able to provide a repro?

fegin avatar Jan 07 '25 21:01 fegin

@fegin

Missing optimizer state for the tied weights should already be fixed a while ago

Thank you for pointing this out. I'm using 2.5.1. It looks like it's a problem with the tied layer definition. I'll check it again and let you know if I find any clues.

yzhangcs avatar Jan 08 '25 04:01 yzhangcs

@fegin I've identified that the issue stems from flatten_optimizer_state_dict, which is set to True by default.

It appears that PyTorch doesn't properly handle tied layers in _unflatten_optim_state_dict (https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict.py#L898-L905). I've created a minimal test case to reproduce the error:

import copy
from typing import Callable

import torch
import torch.nn as nn
from torch.distributed._tensor import init_device_mesh
from torch.distributed.checkpoint.state_dict import (StateDictOptions,
                                                     get_state_dict,
                                                     set_model_state_dict,
                                                     set_optimizer_state_dict)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase, with_comms)
from torch.testing._internal.distributed.common_state_dict import \
    VerifyStateDictMixin


class TiedEmbeddingModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.decoder = nn.Linear(embedding_dim, vocab_size)
        self.decoder.weight = self.embedding.weight  # Tying weights

    def forward(self, input):
        input = (input * 10).to(torch.int)
        embedded = self.embedding(input)
        output = self.decoder(embedded)
        return output


class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
    """Tests state_dict and load_state_dict"""

    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    def _test_save_load(
        self,
        init_model_optim: Callable,
        test_frozen: bool = False,
        flatten_optimizer_state_dict: bool = False
    ) -> None:
        options = StateDictOptions(ignore_frozen_params=test_frozen,
                                   flatten_optimizer_state_dict=flatten_optimizer_state_dict)
        # Initialize original model and distributed model.
        model, optim, copy_optim, dist_model, dist_optim = init_model_optim()

        # Train 10 steps.
        _dist_optim = [dist_optim] if not isinstance(
            dist_optim, list) else dist_optim
        for _ in range(10):
            optim.zero_grad()
            for d_optim in _dist_optim:
                d_optim.zero_grad()

            batch = torch.rand(8, 100, device="cuda")
            model(batch).sum().backward()
            dist_model(batch).sum().backward()

            optim.step()
            for d_optim in _dist_optim:
                d_optim.step()

        # Get the state_dict, and compare the result
        msd = model.state_dict()
        osd = optim.state_dict()
        dist_msd, dist_osd = get_state_dict(
            dist_model, optimizers=dist_optim, options=options
        )
        self._verify_msd(msd, dist_msd, options)
        self._verify_osd_by_load(model, optim, copy_optim, dist_osd)
        self._verify_osd(model, optim, osd, dist_osd)

        # Initialize a completely new model to simulate checkpoint load.
        _, _, _, dist_model, dist_optim = init_model_optim()

        # Simulate DCP distributed load. We need to first get the state_dict and
        # pass them to DCP to load the saved state_dict from the storage.
        # Then finally we can call set_state_dict().
        if not isinstance(dist_optim, list):
            dist_optim = [dist_optim]
        if test_frozen:
            # We won't be able to load the partial state_dict back.
            return
        # Since we already have the state_dict saved before, there is no need to
        # call DCP; we can directly load it back. The assert below is to ensure
        # that the optimizer state storage is initialized.
        # self.assertEqual(len(curr_dist_osd[STATE]), len(dist_osd[STATE]))
        set_model_state_dict(
            dist_model,
            model_state_dict=dist_msd,
            options=options,
        )
        set_optimizer_state_dict(
            dist_model,
            optimizers=dist_optim,
            optim_state_dict=dist_osd,
            options=options,
        )

    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_shared_weight(self):
        def init_model_optim():
            device_mesh = init_device_mesh("cuda", (self.world_size,))
            orig_model = TiedEmbeddingModel(32000, 1024).to(torch.device("cuda"))
            orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4)
            copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4)

            copy_model = copy.deepcopy(orig_model)
            dist_model = FSDP(copy_model, device_mesh=device_mesh)
            dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-4)
            return orig_model, orig_optim, copy_optim, dist_model, dist_optim
        self._test_save_load(init_model_optim, flatten_optimizer_state_dict=False)

    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_shared_weight_flatten(self):
        def init_model_optim():
            device_mesh = init_device_mesh("cuda", (self.world_size,))
            orig_model = TiedEmbeddingModel(32000, 1024).to(torch.device("cuda"))
            orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4)
            copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4)

            copy_model = copy.deepcopy(orig_model)
            dist_model = FSDP(copy_model, device_mesh=device_mesh)
            dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-4)
            return orig_model, orig_optim, copy_optim, dist_model, dist_optim
        self._test_save_load(init_model_optim, flatten_optimizer_state_dict=True)


if __name__ == "__main__":
    run_tests()
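
(To reproduce, run the file directly on a machine with at least 2 GPUs, e.g. python test_state_dict.py; run_tests executes both test methods, and with_comms sets up the process group for each.)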

The output is:

...
KeyError: 'param_groups.decoder.weight.lr'
...
========================================================================== short test summary info ===========================================================================
FAILED test_state_dict.py::TestStateDict::test_shared_weight_flatten - RuntimeError: Process 1 exited with error code 10 and exception:
============================================================= 1 failed, 1 passed, 1 warning in 128.12s (0:02:08) =============================================================

As you can see, flatten_optimizer_state_dict=False works correctly, but when flatten_optimizer_state_dict=True, the KeyError occurs. To work around this issue, I've disabled the option in https://github.com/pytorch/torchtitan/blob/90567fc9827ffdf17bdd0349cd5276c662d0769a/torchtitan/optimizer.py#L59 (the workaround is sketched after the list below). However, I have two concerns:

  • I'm wondering if disabling this option might significantly impact performance, especially w/o PP.
  • It would be great if PyTorch could provide full support for tied layers in the official code.
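
For reference, the workaround boils down to constructing the state-dict options with the flag turned off. A minimal sketch (the surrounding torchtitan code is elided; the variable name here is illustrative):

from torch.distributed.checkpoint.state_dict import StateDictOptions

# Keep the optimizer state_dict nested rather than flattened into
# 'state.<fqn>.<field>' / 'param_groups.<fqn>.<field>' keys, which is where
# the tied-weight KeyError is raised during unflattening.
options = StateDictOptions(flatten_optimizer_state_dict=False)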

Looking forward to your feedback. Thank you.

cc @tianyu-l @awgu

yzhangcs avatar Jan 09 '25 13:01 yzhangcs

  • I'm wondering if disabling this option might significantly impact performance, especially w/o PP.

No, it won't.

  • It would be great if PyTorch could provide full support for tied layers in the official code.

Yes, since this is a bug, I'll take a look at this issue. Thanks!

fegin avatar Jan 09 '25 17:01 fegin

@fegin Hello, I'm wondering if this bug has been fixed in torch?

yzhangcs avatar Mar 04 '25 18:03 yzhangcs

@yzhangcs https://github.com/pytorch/pytorch/pull/148825 fixes the issue.

fegin avatar Mar 08 '25 20:03 fegin

Thank you!

yzhangcs avatar Mar 09 '25 07:03 yzhangcs

@yzhangcs The PR has landed. You should be able to get it with the next nightly PyTorch build. Please let me know if that completely resolves the issue. I can only verify it through unit tests.

fegin avatar Mar 10 '25 20:03 fegin

@yzhangcs I'll close the issue. Let me know if you still encounter issues.

fegin avatar Mar 20 '25 18:03 fegin