How to use the DeepSpeed ZeRO-3 offload strategy correctly? (parameter duplication issue)

Open · KzZheng opened this issue 2 years ago · 12 comments

Hi, I'm wondering how to write the code to use the DeepSpeed ZeRO-3 offload strategy correctly. Currently, my code looks like this:

import lightning as L
from lightning.fabric.strategies import DeepSpeedStrategy

# num_devices: the number of GPUs to train on (set elsewhere in my script)
deep_speed = DeepSpeedStrategy(
    stage=3,
    offload_optimizer=True,
    offload_parameters=True,
)
fabric = L.Fabric(accelerator="gpu", devices=num_devices, strategy=deep_speed)

However, the parameters seem to be duplicated on every GPU. The screenshot below shows GPU utilization after model, optimizer = fabric.setup(model, optimizer):

[Screenshot Selection_282: GPU utilization after fabric.setup(); image not reproduced]

As I understand it, the parameters should be partitioned across the different devices, right?
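For a rough sense of what "partitioned across devices" should look like, the sketch below compares per-GPU parameter memory when every GPU holds a full copy against a simplified ZeRO-3 model in which each parameter is split into equal, padded chunks, one per GPU. The layer shapes and the 2-byte (bf16) element size are illustrative assumptions, not numbers taken from the screenshot, and param_bytes_per_gpu is a hypothetical helper.

```python
import math


def param_bytes_per_gpu(param_numels, world_size, bytes_per_elem=2, sharded=True):
    """Rough per-GPU parameter memory (simplified model, not DeepSpeed's code).

    sharded=False: every GPU holds a full copy of every parameter.
    sharded=True:  ZeRO-3 style, each parameter is split into world_size
                   equal chunks (rounded up), one chunk per GPU.
    """
    if not sharded:
        return sum(param_numels) * bytes_per_elem
    return sum(math.ceil(n / world_size) for n in param_numels) * bytes_per_elem


# Hypothetical layer sizes using LLaMA-7B dimensions (hidden 4096, MLP 11008, vocab 32000)
shapes = [4096 * 4096, 4096 * 11008, 32000 * 4096]
full = param_bytes_per_gpu(shapes, world_size=8, sharded=False)
shard = param_bytes_per_gpu(shapes, world_size=8)
print(full, shard, full // shard)
```

With 8 GPUs the sharded figure is one eighth of the replicated one. With offload_parameters=True, ZeRO-3 additionally keeps those partitions in CPU memory, so GPU usage should be lower still.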

KzZheng, Apr 02 '23

For ZeRO-3 with DeepSpeed, you should wrap the model initialization in this context manager:

with fabric.sharded_model():
    model = ...

Perhaps you forgot this?

awaelchli, Apr 03 '23

Thanks for your reply! I'm new to Fabric and DeepSpeed, so I'm not sure how to add this context manager correctly. Taking lit-llama as an example, should I write it like this?

    with fabric.sharded_model():
        with lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
            model = LLaMA(config)

        checkpoint = torch.load("checkpoints/lit-llama/7B/state_dict.pth")
        
        # strict=False because missing keys due to LoRA weights not contained in checkpoint state
        model.load_state_dict(checkpoint, strict=False) 
        mark_only_lora_as_trainable(model)

I tried this, but I got an error when loading the state dict:

[Screenshot Selection_283: error raised by load_state_dict(); image not reproduced]

I also tried moving load_state_dict() outside the fabric.sharded_model() block, but the issue is the same.

Can you provide me with some hints or code references? Thanks!

KzZheng, Apr 03 '23

Hmm, yes, I see. Loading the checkpoint into a DeepSpeed-sharded model needs a bit more work. Ideally we would use fabric.load() here, but that requires the checkpoint to be a DeepSpeed checkpoint. I need to think about how we could detect that case and load it properly.
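The size-mismatch errors can be reproduced in miniature. Assuming, as is my understanding of ZeRO-3, that each rank's visible copy of a partitioned parameter is an empty placeholder, a shape-checked copy of a full checkpoint tensor fails, whereas writing the full value once and then re-partitioning it works. In real DeepSpeed, gathering parameters so they can be modified is what deepspeed.zero.GatheredParameters(params, modifier_rank=0) provides; the functions below are a hypothetical pure-Python illustration, not that API.

```python
import math


def partition(values, world_size):
    """Split a flat list into world_size equal chunks, zero-padding the last one."""
    chunk = math.ceil(len(values) / world_size)
    padded = list(values) + [0.0] * (chunk * world_size - len(values))
    return [padded[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def gather(shards, numel):
    """Concatenate every rank's chunk and strip the padding."""
    return [v for shard in shards for v in shard][:numel]


def strict_copy(local, full_value):
    """Mimic load_state_dict's shape check for a single tensor."""
    if len(local) != len(full_value):
        raise ValueError(
            f"size mismatch: checkpoint has {len(full_value)} elements, "
            f"local tensor has {len(local)}"
        )
    return list(full_value)


checkpoint_value = [float(i) for i in range(10)]
placeholder = []  # what a rank "sees" for a partitioned parameter

try:
    strict_copy(placeholder, checkpoint_value)
except ValueError as err:
    print(err)

# Split the full checkpoint value across 4 ranks; gathering the shards back
# recovers the original values.
shards = partition(checkpoint_value, world_size=4)
print(shards[3], gather(shards, len(checkpoint_value)) == checkpoint_value)
```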

awaelchli, Apr 04 '23

I'm facing the same issue with LoRA and DeepSpeed: a bunch of size mismatch errors.

timothylimyl, May 05 '23

I'm facing the same issue. Should the existing LLaMA checkpoint first be converted to a DeepSpeed checkpoint?

HeorhiiS, Jun 05 '23

Any updates on this?

alexgshaw, Jun 26 '23

I got the model running by first converting the weights to DeepSpeed checkpoints and then loading the model from those checkpoints.

I set up the DeepSpeed strategy as follows:

from lightning.fabric.strategies import DeepSpeedStrategy

deep_off = DeepSpeedStrategy(config="deep_config.json")

This is the config I used (deep_config.json):

{
    "bf16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    }
} 
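Before launching, it can help to confirm that a config file really requests ZeRO stage 3 with CPU offload for both optimizer state and parameters, which is what stage=3, offload_optimizer=True and offload_parameters=True ask for in the original question. A minimal stdlib sketch; describe_zero is a hypothetical helper, and the defaults it falls back to (stage 0, offload device "none") are assumed to mirror DeepSpeed's.

```python
import json


def describe_zero(config):
    """Summarize the ZeRO settings of a DeepSpeed config dict."""
    zero = config.get("zero_optimization", {})
    return {
        "stage": zero.get("stage", 0),
        "optimizer_offload": zero.get("offload_optimizer", {}).get("device", "none"),
        "param_offload": zero.get("offload_param", {}).get("device", "none"),
    }


# In practice:
#     with open("deep_config.json") as f:
#         config = json.load(f)
# Here, an inline excerpt of the config above keeps the example self-contained.
sample = json.loads("""
{
    "bf16": {"enabled": true},
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu", "pin_memory": true},
        "offload_param": {"device": "cpu", "pin_memory": true}
    }
}
""")
print(describe_zero(sample))
```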

I then created the Fabric object as follows:

import lightning as L

fabric = L.Fabric(
    accelerator="cuda", devices=devices, precision="bf16-mixed", strategy=deep_off
)

Then I loaded the checkpoints as follows:

import re

import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam
from lit_llama.model import LLaMA

# config, learning_rate, train, train_data, val_data and out_dir come from
# the surrounding fine-tuning script.
checkpoint_paths = [
    "zero_pp_rank_0_mp_rank_00_model_states.pt",
    "zero_pp_rank_1_mp_rank_00_model_states.pt",
    "zero_pp_rank_2_mp_rank_00_model_states.pt",
    "zero_pp_rank_3_mp_rank_00_model_states.pt",
    "zero_pp_rank_4_mp_rank_00_model_states.pt",
    "zero_pp_rank_5_mp_rank_00_model_states.pt",
    "zero_pp_rank_6_mp_rank_00_model_states.pt",
    "zero_pp_rank_7_mp_rank_00_model_states.pt",
]
merged_checkpoint = {}
for checkpoint_path in checkpoint_paths:
    match = re.search(r"rank_(\d+)", checkpoint_path)
    rank_num = int(match.group(1))
    # Each process only loads the file whose rank matches its own, so
    # merged_checkpoint ends up holding that single file's non-None entries.
    if fabric.global_rank == rank_num:
        checkpoint = torch.load(checkpoint_path)
        checkpoint = {k: v for k, v in checkpoint.items() if v is not None}
        for key, value in checkpoint.items():
            if key not in merged_checkpoint:
                merged_checkpoint[key] = value
            else:
                try:
                    merged_checkpoint[key] += value
                except TypeError:
                    merged_checkpoint[key].update(value)
checkpoint = merged_checkpoint

# with fabric.device:
with fabric.init_module():
    # Build the model in half precision to save memory, then restore the default.
    torch.set_default_tensor_type(torch.HalfTensor)
    model = LLaMA(config).bfloat16()
    torch.set_default_tensor_type(torch.FloatTensor)
    model.load_state_dict(checkpoint, strict=False)

optimizer = DeepSpeedCPUAdam(model.parameters(), lr=learning_rate)
model, optimizer = fabric.setup(model, optimizer)
train(fabric, model, optimizer, train_data, val_data, out_dir)
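Since each process only opens the file whose rank number matches fabric.global_rank, the file-selection part of the loop amounts to a lookup by rank. The sketch below isolates that selection with the same regular expression; shard_file_for_rank is a hypothetical helper, and the file names are the eight listed in the snippet.

```python
import re


def shard_file_for_rank(paths, rank):
    """Return the path whose first "rank_<n>" field equals rank (same regex as above)."""
    for path in paths:
        match = re.search(r"rank_(\d+)", path)
        if match and int(match.group(1)) == rank:
            return path
    raise FileNotFoundError(f"no model-state file for rank {rank}")


paths = [f"zero_pp_rank_{r}_mp_rank_00_model_states.pt" for r in range(8)]
print(shard_file_for_rank(paths, 3))
```

Because re.search stops at the first match, the pattern reads the zero_pp_rank_<n> field rather than the trailing mp_rank_00.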

You also have to comment out the following line in the train function, because it doesn't work with DeepSpeed:

# with fabric.no_backward_sync(model, enabled=is_accumulating):
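Commenting out that line only removes the context manager around the backward pass; the decision of when to call optimizer.step() stays the same. Assuming the train loop follows lit-llama's pattern of is_accumulating = (iter_num + 1) % accumulation_steps != 0, this sketch lists the iterations that would step the optimizer. optimizer_step_iters is a hypothetical helper, and DeepSpeed's own gradient_accumulation_steps config key is not modeled here.

```python
def optimizer_step_iters(num_iters, accumulation_steps):
    """Iterations (0-based) at which optimizer.step() would run."""
    steps = []
    for iter_num in range(num_iters):
        is_accumulating = (iter_num + 1) % accumulation_steps != 0
        # backward() would run on every iteration; only the last micro-batch
        # of each group of accumulation_steps triggers an optimizer step.
        if not is_accumulating:
            steps.append(iter_num)
    return steps


print(optimizer_step_iters(num_iters=8, accumulation_steps=4))
```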

This should work, but I'm sure there is a better way to do it.

scvance, Jul 10 '23

@scvance I'll check it out. Was it a full-model checkpoint or a LoRA one?

HeorhiiS, Jul 11 '23

@HeorhiiS It was a full 7B model. Note that it trained slower than the normal model.

scvance, Jul 11 '23

Have there been any updates on this? I'm also trying to figure out how to use DeepSpeed properly (with Mistral 7B in my case), but I can't find any examples of using it with Fabric.

WilliamGazeley, Nov 24 '23

@scvance Any chance you could upload the full script you used to make this work?

WilliamGazeley, Nov 28 '23

mark

qgzang, Jul 25 '24