fix: extend the unwrap_model function and save unwrapped model state dict instead of wrapped
What does this PR do?
This PR introduces two changes:
- Save `unwrap_model(model).state_dict()` instead of the wrapped model's state dict whenever `isinstance(unwrap_model(model), supported_classes)` holds.
- Extend `unwrap_model()` so that wrappers around the model's child layers are also unwrapped correctly.

With the existing `unwrap_model()`, only the outermost wrapper is removed, which fails when we wrap with FSDP because it does not descend into the children layers or modules.
For example, here is a wrapped model:
SpmdFullyShardedDataParallel(
  (_orig_module): GemmaForCausalLM(
    (model): GemmaModel(
      (embed_tokens): Embedding(256000, 2048, padding_idx=0)
      (layers): ModuleList(
        (0-17): 18 x SpmdFullyShardedDataParallel(
          (_orig_module): GemmaDecoderLayer(
            (self_attn): GemmaAttention(
              (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
              (k_proj): Linear(in_features=2048, out_features=256, bias=False)
              (v_proj): Linear(in_features=2048, out_features=256, bias=False)
              (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
              (rotary_emb): GemmaRotaryEmbedding()
            )
            (mlp): GemmaMLP(
              (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
              (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
              (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
              (act_fn): PytorchGELUTanh()
            )
            (input_layernorm): GemmaRMSNorm()
            (post_attention_layernorm): GemmaRMSNorm()
          )
        )
      )
      (norm): GemmaRMSNorm()
    )
    (lm_head): Linear(in_features=2048, out_features=256000, bias=False)
  )
)
Unwrapping it with the existing `unwrap_model()` leads to:
GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x SpmdFullyShardedDataParallel(
        (_orig_module): GemmaDecoderLayer(
          (self_attn): GemmaAttention(
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (k_proj): Linear(in_features=2048, out_features=256, bias=False)
            (v_proj): Linear(in_features=2048, out_features=256, bias=False)
            (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (rotary_emb): GemmaRotaryEmbedding()
          )
          (mlp): GemmaMLP(
            (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
            (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
            (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
            (act_fn): PytorchGELUTanh()
          )
          (input_layernorm): GemmaRMSNorm()
          (post_attention_layernorm): GemmaRMSNorm()
        )
      )
    )
    (norm): GemmaRMSNorm()
  )
  (lm_head): Linear(in_features=2048, out_features=256000, bias=False)
)
But with the change introduced in this PR, we get:
GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
      )
    )
    (norm): GemmaRMSNorm()
  )
  (lm_head): Linear(in_features=2048, out_features=256000, bias=False)
)
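For reference, here is a minimal sketch of what change (1), saving the unwrapped state dict, amounts to on the saving side. It is illustrative only and not the actual `trainer.py` diff; the function name `save_model_state` and the fallback branch are assumptions:

```python
import torch
from transformers import PreTrainedModel
from transformers.modeling_utils import unwrap_model


def save_model_state(model, output_dir: str, safe_serialization: bool = True):
    """Sketch of change (1): save the unwrapped model's state dict when possible."""
    supported_classes = (PreTrainedModel,)  # the Trainer also accepts PEFT models here
    unwrapped = unwrap_model(model)
    if isinstance(unwrapped, supported_classes):
        # Save unwrapped.state_dict() instead of model.state_dict(): the wrapper's keys
        # carry prefixes such as "_orig_module." and don't match the checkpoint format.
        unwrapped.save_pretrained(
            output_dir,
            state_dict=unwrapped.state_dict(),
            safe_serialization=safe_serialization,
        )
    else:
        torch.save(model.state_dict(), f"{output_dir}/pytorch_model.bin")
```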
Fixes #29659
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
@amyeroberts @muellerzr @pacman100 Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@alanwaketan can you also take a look, please?
You can replicate the wrapping and unwrapping using this script:
import torch
import torch_xla
import torch.nn as nn
import functools
from transformers import AutoModelForCausalLM
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
from torch_xla.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
)
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
import numpy as np
from torch_xla.distributed.fsdp import checkpoint_module
from transformers.trainer_pt_utils import get_module_class_from_name
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import unwrap_model
def wrap_model(model, fsdp_config):
    num_devices = xr.global_runtime_device_count()
    xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))

    auto_wrap_policy = None
    auto_wrapper_callable = None
    default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
    fsdp_transformer_layer_cls_to_wrap = fsdp_config.get(
        "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
    )
    if fsdp_transformer_layer_cls_to_wrap is not None:
        transformer_cls_to_wrap = set()
        for layer_class in fsdp_transformer_layer_cls_to_wrap:
            print(f"layer class is {layer_class}")
            transformer_cls = get_module_class_from_name(model, layer_class)
            if transformer_cls is None:
                raise Exception("Could not find the transformer layer class to wrap in the model.")
            else:
                transformer_cls_to_wrap.add(transformer_cls)
        print(f"transformer_cls_to_wrap: {transformer_cls_to_wrap}")
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            # Transformer layer class to wrap
            transformer_layer_cls=transformer_cls_to_wrap,
        )

    if fsdp_config["xla_fsdp_grad_ckpt"]:
        # Apply gradient checkpointing to auto-wrapped sub-modules if specified
        def auto_wrapper_callable(m, *args, **kwargs):
            target_cls = FSDPv2
            return target_cls(checkpoint_module(m), *args, **kwargs)

    def shard_output(output, mesh):
        real_output = None
        if isinstance(output, torch.Tensor):
            real_output = output
        elif isinstance(output, tuple):
            real_output = output[0]
        elif isinstance(output, CausalLMOutputWithPast):
            real_output = output.logits
        if real_output is None:
            raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
        xs.mark_sharding(real_output, mesh, ("fsdp", None, None))

    print(f"auto wrap policy is {auto_wrap_policy}")
    print(f"auto wrapper callable is {auto_wrapper_callable}")
    model = FSDPv2(
        model,
        shard_output=shard_output,
        auto_wrap_policy=auto_wrap_policy,
        auto_wrapper_callable=auto_wrapper_callable,
    )
    return model
def unwrap_model_new(model: nn.Module) -> nn.Module:
    """
    Recursively unwraps a module and its child sublayers.

    Args:
        model (nn.Module): Module to unwrap.

    Returns:
        nn.Module: The unwrapped module.
    """

    def recursive_unwrap(module):
        if hasattr(module, "module"):
            try:
                unwrapped_module = recursive_unwrap(getattr(module, "module"))
            except AttributeError:
                unwrapped_module = module  # Handle cases where wrapped module is inaccessible
            return unwrapped_module

        # Unwrap child sublayers recursively
        for name, child in module.named_children():
            setattr(module, name, recursive_unwrap(child))
        return module

    # Start with top-level unwrapping
    unwrapped_model = recursive_unwrap(model)
    return unwrapped_model
def main():
    model_id = "google/gemma-2b"

    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

    fsdp_config = {
        "fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"],
        "xla": True,
        "xla_fsdp_v2": True,
        "xla_fsdp_grad_ckpt": True,
    }

    wrapped_model = wrap_model(model, fsdp_config)
    print(wrapped_model)

    unwrapped_model_old = unwrap_model(wrapped_model)
    print(unwrapped_model_old)

    unwrapped_model_new = unwrap_model_new(wrapped_model)
    print(unwrapped_model_new)


if __name__ == "__main__":
    main()
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@amyeroberts I had to change `unwrap_model` because of the changes introduced in #28949 (Support PyTorch/XLA FSDP via SPMD); the existing `unwrap_model` only fails there. I can write a test, but the problem is that it requires a TPU, and I am not sure whether we have one as part of our CI runners. So, how should we proceed here?
@amyeroberts here is a small snippet for the test:
import torch
import torch_xla
import torch.nn as nn
from transformers import AutoModelForCausalLM
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
import numpy as np
import unittest
def compare_state_dict_keys(state_dict_keys_model1, state_dict_keys_model2):
    # Guard against one key list being a strict prefix of the other, which zip() would hide.
    if len(state_dict_keys_model1) != len(state_dict_keys_model2):
        return False
    for key1, key2 in zip(state_dict_keys_model1, state_dict_keys_model2):
        if key1 != key2:
            # print(f"Keys are not equal")
            # print(key1, key2)
            return False
    return True
# Original `unwrap_model` function
def original_unwrap_model(model: nn.Module) -> nn.Module:
    """Original unwrap implementation for comparison."""
    if hasattr(model, "module"):
        return original_unwrap_model(model.module)
    else:
        return model
def unwrap_model_new(model: nn.Module) -> nn.Module:
    """
    Recursively unwraps a module and its child sublayers.

    Args:
        model (nn.Module): Module to unwrap.

    Returns:
        nn.Module: The unwrapped module.
    """

    def recursive_unwrap(module):
        if hasattr(module, "module"):
            unwrapped_module = recursive_unwrap(getattr(module, "module"))
        else:
            unwrapped_module = module  # Handle cases where wrapped module is inaccessible

        # Unwrap child sublayers recursively
        for name, child in module.named_children():
            setattr(module, name, recursive_unwrap(child))
        return unwrapped_module

    # Start with top-level unwrapping
    unwrapped_model = recursive_unwrap(model)
    return unwrapped_model
class TestUnwrap(unittest.TestCase):
    def test_compatibility_with_original_behavior(self):
        model_id = "mistralai/Mistral-7B-v0.1"
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

        num_devices = xr.global_runtime_device_count()
        xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))

        wrapped_model = FSDPv2(model)

        unwrapped_model_old = original_unwrap_model(wrapped_model)
        state_dict_keys_model1 = list(unwrapped_model_old.state_dict().keys())

        unwrapped_model_new = unwrap_model_new(wrapped_model)
        state_dict_keys_model2 = list(unwrapped_model_new.state_dict().keys())

        assert compare_state_dict_keys(state_dict_keys_model1, state_dict_keys_model2) == True

    def test_nested_unwrap_modules(self):
        model_id = "mistralai/Mistral-7B-v0.1"
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        orig_state_dict_keys = list(model.state_dict().keys())

        num_devices = xr.global_runtime_device_count()
        xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))

        def nested_wrap(model):
            layer = getattr(getattr(model, "model"), "embed_tokens")
            wrapped_layer = FSDPv2(layer)
            setattr(getattr(model, "model"), "embed_tokens", wrapped_layer)
            return FSDPv2(model)

        wrapped_model = nested_wrap(model)

        unwrapped_model_old = original_unwrap_model(wrapped_model)
        old_state_dict_keys = list(unwrapped_model_old.state_dict().keys())

        unwrapped_model_new = unwrap_model_new(wrapped_model)
        new_state_dict_keys = list(unwrapped_model_new.state_dict().keys())

        assert compare_state_dict_keys(old_state_dict_keys, orig_state_dict_keys) == False
        assert compare_state_dict_keys(new_state_dict_keys, orig_state_dict_keys) == True


# if __name__ == "__main__":
#     test_unwrap = TestUnwrap()
#     test_unwrap.test_compatibility_with_original_behavior()
#     test_unwrap.test_nested_unwrap_modules()
It can be run using:
python -m unittest test_unwrap_model.py
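Since the test needs a TPU, one option for CI is to guard it with a skip so it is ignored on non-TPU runners. A minimal sketch (the `PJRT_DEVICE` check is just one assumed way to detect a TPU runtime, and the top-level `torch_xla` imports would also have to be made conditional):

```python
import os
import unittest

# Skip the TPU-dependent tests unless the runner is configured for a TPU.
# Checking PJRT_DEVICE is an assumption about the runner setup; a dedicated
# decorator from transformers' testing utilities could be used instead if available.
requires_tpu = unittest.skipUnless(
    os.environ.get("PJRT_DEVICE") == "TPU", "test requires a TPU runtime"
)


@requires_tpu
class TestUnwrap(unittest.TestCase):
    ...  # tests as defined above
```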
New proposal for this, under which @shub-kris's work here can still be done. This should be merged/worked on in the following order:
- We're expanding this implementation into Accelerate via this PR.
- https://github.com/huggingface/transformers/pull/29933 should be merged, which brings in the Accelerate implementation instead of the one in transformers, after we ensure that the old behaviors match.
- Afterwards, we should pass `recursive=True` specifically in the TPU saving portion.
@muellerzr how is this PR going? I see that the upstream accelerate PR 2595 has been merged.
> New proposal for this, under which @shub-kris's work here can still be done. This should be merged/worked on in the following order:
> - We're expanding this implementation into Accelerate via this PR.
> - Update unwrap from accelerate #29933 should be merged, which brings in the Accelerate implementation instead of the one in transformers, after we ensure that the old behaviors match.
> - Afterwards, we should pass `recursive=True` specifically in the TPU saving portion.
Points 1 and 2 have both been merged. @muellerzr can you help us get to step 3?
If @shub-kris wants to rebase, the changes in `trainer.py` are no longer needed; just passing `recursive=True` is needed, thanks to https://github.com/huggingface/transformers/pull/29933.