
fix: extend the unwrap_model function and save unwrapped model state dict instead of wrapped

Open shub-kris opened this pull request 11 months ago • 3 comments

What does this PR do?

This PR pushes two changes:

  • Save unwrap_model(model).state_dict() whenever isinstance(unwrap_model(model), supported_classes) is true, instead of the wrapped model's state dict.
  • Extend unwrap_model() so that wrappers around the model's child layers are also unwrapped correctly.

With the existing unwrap_model(), only the outermost wrapper is removed. This fails when wrapping with FSDP, because the function doesn't descend into the child layers or modules.
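
For reference, the existing unwrap_model() in modeling_utils.py is roughly the following (paraphrased; it only follows the top-level .module chain and never descends into children):

def unwrap_model(model: "nn.Module") -> "nn.Module":
    # Existing behavior (paraphrased): peel the outermost wrapper(s) only.
    # FSDP wrappers applied to inner decoder layers are children of the model,
    # not part of this top-level `.module` chain, so they are left in place.
    if hasattr(model, "module"):
        return unwrap_model(model.module)
    return model

The recursive variant proposed in this PR also walks named_children() and replaces each wrapped child in place, which is what makes the difference in the example below.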

For example:

A Wrapped Model

SpmdFullyShardedDataParallel(
  (_orig_module): GemmaForCausalLM(
    (model): GemmaModel(
      (embed_tokens): Embedding(256000, 2048, padding_idx=0)
      (layers): ModuleList(
        (0-17): 18 x SpmdFullyShardedDataParallel(
          (_orig_module): GemmaDecoderLayer(
            (self_attn): GemmaAttention(
              (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
              (k_proj): Linear(in_features=2048, out_features=256, bias=False)
              (v_proj): Linear(in_features=2048, out_features=256, bias=False)
              (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
              (rotary_emb): GemmaRotaryEmbedding()
            )
            (mlp): GemmaMLP(
              (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
              (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
              (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
              (act_fn): PytorchGELUTanh()
            )
            (input_layernorm): GemmaRMSNorm()
            (post_attention_layernorm): GemmaRMSNorm()
          )
        )
      )
      (norm): GemmaRMSNorm()
    )
    (lm_head): Linear(in_features=2048, out_features=256000, bias=False)
  )
)

Unwrapping it with the existing unwrap_model() leads to:

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x SpmdFullyShardedDataParallel(
        (_orig_module): GemmaDecoderLayer(
          (self_attn): GemmaAttention(
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (k_proj): Linear(in_features=2048, out_features=256, bias=False)
            (v_proj): Linear(in_features=2048, out_features=256, bias=False)
            (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (rotary_emb): GemmaRotaryEmbedding()
          )
          (mlp): GemmaMLP(
            (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
            (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
            (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
            (act_fn): PytorchGELUTanh()
          )
          (input_layernorm): GemmaRMSNorm()
          (post_attention_layernorm): GemmaRMSNorm()
        )
      )
    )
    (norm): GemmaRMSNorm()
  )
  (lm_head): Linear(in_features=2048, out_features=256000, bias=False)
)

But with the change proposed in this PR:

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
      )
    )
    (norm): GemmaRMSNorm()
  )
  (lm_head): Linear(in_features=2048, out_features=256000, bias=False)
)


Fixes #29659

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [ ] Did you read the contributor guideline, Pull Request section?
  • [ ] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [ ] Did you write any new necessary tests?

Who can review?

@amyeroberts @muellerzr @pacman100 Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

shub-kris avatar Mar 21 '24 13:03 shub-kris

@alanwaketan can you also take a look please ?

shub-kris avatar Mar 21 '24 13:03 shub-kris

You can replicate the wrapping and unwrapping using this script:

import torch
import torch_xla
import torch.nn as nn
import functools
from transformers import AutoModelForCausalLM
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
from torch_xla.distributed.fsdp.wrap import (
                    transformer_auto_wrap_policy,
                )
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
import numpy as np
from torch_xla.distributed.fsdp import checkpoint_module
from transformers.trainer_pt_utils import get_module_class_from_name
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import unwrap_model

def wrap_model(model, fsdp_config):
    num_devices = xr.global_runtime_device_count()
    xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
    
    auto_wrap_policy = None
    auto_wrapper_callable = None
    default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
    fsdp_transformer_layer_cls_to_wrap = fsdp_config.get(
        "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
    )

    if fsdp_transformer_layer_cls_to_wrap is not None:
        transformer_cls_to_wrap = set()
        for layer_class in fsdp_transformer_layer_cls_to_wrap:
            print(f"layer class is {layer_class}")
            transformer_cls = get_module_class_from_name(model, layer_class)
            if transformer_cls is None:
                raise Exception("Could not find the transformer layer class to wrap in the model.")
            else:
                transformer_cls_to_wrap.add(transformer_cls)
        print(f"transformer_cls_to_wrap: {transformer_cls_to_wrap}")
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            # Transformer layer class to wrap
            transformer_layer_cls=transformer_cls_to_wrap,
        )
        if fsdp_config["xla_fsdp_grad_ckpt"]:
            # Apply gradient checkpointing to auto-wrapped sub-modules if specified
            def auto_wrapper_callable(m, *args, **kwargs):
                target_cls = FSDPv2
                return target_cls(checkpoint_module(m), *args, **kwargs)


            def shard_output(output, mesh):
                real_output = None
                if isinstance(output, torch.Tensor):
                    real_output = output
                elif isinstance(output, tuple):
                    real_output = output[0]
                elif isinstance(output, CausalLMOutputWithPast):
                    real_output = output.logits

                if real_output is None:
                    raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
                xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
            
            print(f"auto wrap policy is {auto_wrap_policy}")
            print(f"auto wrapper callable is {auto_wrapper_callable}")
            model = FSDPv2(
                model,
                shard_output=shard_output,
                auto_wrap_policy=auto_wrap_policy,
                auto_wrapper_callable=auto_wrapper_callable,
            )
        return model
    


def unwrap_model_new(model: nn.Module) -> nn.Module:
    """
    Recursively unwraps a module and its child sublayers.

    Args:
        model (nn.Module): Module to unwrap.

    Returns:
        nn.Module: The unwrapped module.
    """

    def recursive_unwrap(module):
        if hasattr(module, "module"):
            try:
                unwrapped_module = recursive_unwrap(getattr(module, "module"))
            except AttributeError:
                unwrapped_module = module  # Handle cases where wrapped module is inaccessible
            return unwrapped_module

        # Unwrap child sublayers recursively
        for name, child in module.named_children():
            setattr(module, name, recursive_unwrap(child))

        return module

    # Start with top-level unwrapping
    unwrapped_model = recursive_unwrap(model)
    return unwrapped_model

def main():
    model_id = "google/gemma-2b"
    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    
    fsdp_config = {
        "fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"],
        "xla": True,
        "xla_fsdp_v2": True,
        "xla_fsdp_grad_ckpt": True,
    }
    wrapped_model = wrap_model(model, fsdp_config)
    print(wrapped_model)
    
    unwrapped_model_old = unwrap_model(wrapped_model)
    print(unwrapped_model_old)
    
    unwrapped_model_new = unwrap_model_new(wrapped_model)
    print(unwrapped_model_new)
    
if __name__ == "__main__":
    main()

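(For reference, the script above assumes a TPU VM with torch_xla installed; under the PJRT runtime it can be launched as, e.g., PJRT_DEVICE=TPU python repro_unwrap.py, where the filename is hypothetical.)
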
shub-kris avatar Mar 21 '24 13:03 shub-kris

@amyeroberts I had to change unwrap_model because of the changes introduced in #28949 (Support PyTorch/XLA FSDP via SPMD); the existing unwrap_model only fails in that setting. I can write a test, but the problem is that it requires a TPU, and I'm not sure we have one as part of our CI runners.

So, how should we proceed here?
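
In case a TPU runner isn't available, one possible hardware-independent test could emulate the FSDP wrapping with a plain nn.Module wrapper that exposes a module property. A minimal sketch, assuming the recursive unwrap_model from this PR (DummyWrapper and the test name are made up for illustration):

import unittest

import torch.nn as nn

from transformers.modeling_utils import unwrap_model  # assumes the recursive version proposed in this PR


class DummyWrapper(nn.Module):
    """Stand-in for an FSDP-style wrapper: exposes the wrapped module as `.module`."""

    def __init__(self, module):
        super().__init__()
        self._orig_module = module

    @property
    def module(self):
        return self._orig_module

    def forward(self, *args, **kwargs):
        return self._orig_module(*args, **kwargs)


class TestRecursiveUnwrap(unittest.TestCase):
    def test_nested_wrappers_are_removed(self):
        model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
        expected_keys = list(model.state_dict().keys())

        # Wrap a child module and then the whole model, mimicking FSDP auto-wrapping
        model[0] = DummyWrapper(model[0])
        wrapped = DummyWrapper(model)

        # With the recursive unwrap, the state dict keys should match the original model
        unwrapped = unwrap_model(wrapped)
        self.assertEqual(list(unwrapped.state_dict().keys()), expected_keys)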

shub-kris avatar Mar 26 '24 17:03 shub-kris

@amyeroberts here is a small snippet for the test:

import torch
import torch_xla
import torch.nn as nn
from transformers import AutoModelForCausalLM
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
import numpy as np
import unittest

def compare_state_dict_keys(state_dict_keys_model1, state_dict_keys_model2):
    # zip() silently stops at the shorter list, so compare lengths first
    if len(state_dict_keys_model1) != len(state_dict_keys_model2):
        return False
    for key1, key2 in zip(state_dict_keys_model1, state_dict_keys_model2):
        if key1 != key2:
            return False
    return True

# Original `unwrap_model` function
def original_unwrap_model(model: nn.Module) -> nn.Module:
    """Original unwrap implementation for comparison."""
    if hasattr(model, "module"):
        return original_unwrap_model(model.module)
    else:
        return model

def unwrap_model_new(model: nn.Module) -> nn.Module:
    """
    Recursively unwraps a module and its child sublayers.

    Args:
        model (nn.Module): Module to unwrap.

    Returns:
        nn.Module: The unwrapped module.
    """

    def recursive_unwrap(module):
        if hasattr(module, "module"):
            unwrapped_module = recursive_unwrap(getattr(module, "module"))
        else:
            unwrapped_module = module  # Not wrapped; use the module as-is

        # Unwrap child sublayers recursively
        for name, child in module.named_children():
            setattr(module, name, recursive_unwrap(child))

        return unwrapped_module

    # Start with top-level unwrapping
    unwrapped_model = recursive_unwrap(model)
    return unwrapped_model

class TestUnwrap(unittest.TestCase):    
    def test_compatibility_with_original_behavior(self):
        model_id = "mistralai/Mistral-7B-v0.1"
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        num_devices = xr.global_runtime_device_count()
        xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
        
        wrapped_model = FSDPv2(model)
        unwrapped_model_old = original_unwrap_model(wrapped_model)
        state_dict_keys_model1 = list(unwrapped_model_old.state_dict().keys())
        unwrapped_model_new = unwrap_model_new(wrapped_model)
        state_dict_keys_model2 = list(unwrapped_model_new.state_dict().keys())

        self.assertTrue(compare_state_dict_keys(state_dict_keys_model1, state_dict_keys_model2))
        
    def test_nested_unwrap_modules(self):
        model_id = "mistralai/Mistral-7B-v0.1"
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        orig_state_dict_keys = list(model.state_dict().keys())
        num_devices = xr.global_runtime_device_count()
        xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
        def nested_wrap(model):
            layer = getattr(getattr(model, "model"), "embed_tokens")
            wrapped_layer = FSDPv2(layer)
            setattr(getattr(model, "model"), "embed_tokens", wrapped_layer)
            return FSDPv2(model)
        wrapped_model = nested_wrap(model)
        unwrapped_model_old = original_unwrap_model(wrapped_model)
        old_state_dict_keys = list(unwrapped_model_old.state_dict().keys())
        unwrapped_model_new = unwrap_model_new(wrapped_model)
        new_state_dict_keys = list(unwrapped_model_new.state_dict().keys())
        self.assertFalse(compare_state_dict_keys(old_state_dict_keys, orig_state_dict_keys))
        self.assertTrue(compare_state_dict_keys(new_state_dict_keys, orig_state_dict_keys))

if __name__ == "__main__":
    unittest.main()

It can be run using:

python -m unittest test_unwrap_model.py

shub-kris avatar Mar 27 '24 07:03 shub-kris

New proposal for this, under which @shub-kris's work here can still be used:

This should be merged/worked on in the following order:

  1. We're expanding this implementation into accelerate via this PR
  2. https://github.com/huggingface/transformers/pull/29933 should be merged, which brings in the Accelerate implementation instead of the one in transformers, after we ensure that the old behaviors match
  3. Afterwards, we should pass recursive=True specifically under the TPU saving portion (see the sketch below)
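
For illustration, step 3 might look roughly like the following in the Trainer's TPU saving path. This is a minimal sketch, not the actual diff: the helper name is made up, and it assumes the recursive flag exposed by Accelerate's extract_model_from_parallel (added via the Accelerate PR referenced above):

from accelerate.utils import extract_model_from_parallel


def save_tpu_state_dict(model, supported_classes):
    # Hypothetical helper: unwrap recursively so nested FSDPv2 wrappers around
    # decoder layers are removed before the state dict is taken.
    unwrapped = extract_model_from_parallel(model, recursive=True)
    if isinstance(unwrapped, supported_classes):
        return unwrapped.state_dict()  # keys without wrapper prefixes such as `_orig_module`
    return model.state_dict()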

muellerzr avatar Mar 28 '24 16:03 muellerzr

@muellerzr how is this PR going? I see the upstream accelerate PR 2595 has been merged.

zorrofox avatar Apr 18 '24 07:04 zorrofox

New proposal for this, under which @shub-kris's work here can still be used:

This should be merged/worked on in the following order:

  1. We're expanding this implementation into accelerate via this PR
  2. Update unwrap from accelerate #29933 should be merged, which brings in the Accelerate implementation instead of the one in transformers, after we ensure that the old behaviors match
  3. Afterwards, we should pass recursive=True specifically under the TPU saving portion

Points 1 and 2 have both been merged. @muellerzr can you help move this to step 3?

zorrofox avatar Apr 22 '24 03:04 zorrofox

If @shub-kris wants to rebase, the changes in trainer.py are no longer needed; just passing recursive=True is enough, thanks to https://github.com/huggingface/transformers/pull/29933.

muellerzr avatar Apr 24 '24 16:04 muellerzr