
[BUG] Deepspeed-Inference: support AutoTP for Llama-4 models

Open · songdezhao opened this issue on May 10, 2025 · 5 comments

Describe the bug
I was trying to run DeepSpeed-Inference on Llama-4-Scout-Instruct for text generation. The process failed when it started to load the model.

I am running on a single node with 8 GPUs, each with 80 GB of memory (8×80 GB total). I used float16.

Here is the error message:

AutoTP:  [(<class 'transformers.models.llama4.modeling_llama4.Llama4TextDecoderLayer'>, ['shared_expert.down_proj', 'self_attn.o_proj'])]
Loading 0 checkpoint shards: 0it [00:00, ?it/s][rank0]: Traceback (most recent call last):
[rank0]:     ds_engine = deepspeed.init_inference(model, config=ds_inference_config)
[rank0]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/site-packages/deepspeed/__init__.py", line 364, in init_inference
[rank0]:     engine = InferenceEngine(model, config=ds_inference_config)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/site-packages/deepspeed/inference/engine.py", line 164, in __init__
[rank0]:     self._apply_injection_policy(config, client_module)
[rank0]:   File "/usr/local/lib/python3.12/site-packages/deepspeed/inference/engine.py", line 388, in _apply_injection_policy
[rank0]:     replace_transformer_layer(client_module, self.module, checkpoint, config, self.config)
[rank0]:   File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/replace_module.py", line 397, in replace_transformer_layer
[rank0]:     if 'Yuan' in str(replaced_module):
[rank0]:                      ^^^^^^^^^^^^^^^
[rank0]: UnboundLocalError: cannot access local variable 'replaced_module' where it is not associated with a value

To Reproduce
Steps to reproduce the behavior:

  1. Simple inference script to reproduce. Here is a code snippet that reproduces the error:
import glob
import os

import deepspeed
import torch
from transformers import Llama4TextConfig, AutoModelForCausalLM

model_path = "./model_to_evaluate"

kwargs = {"torch_dtype": torch.float16, "attn_implementation": "sdpa"}
print(f"kwargs: {kwargs}")

model_config = Llama4TextConfig.from_pretrained(model_path, **kwargs)

# load model using meta device
with deepspeed.OnDevice(dtype=kwargs["torch_dtype"], device="meta", enabled=True):
    model = AutoModelForCausalLM.from_config(model_config, **kwargs)
print(f"model device: {next(model.parameters()).device}")

# set up deepspeed inference config
ds_inference_config = {
    "dtype": kwargs["torch_dtype"],
    # meta device is not compatible with kernel injection
    "replace_with_kernel_inject": False,
    # tp equals to the global number of gpus
    "tensor_parallel": {
        "tp_size": int(os.getenv("WORLD_SIZE", "1"))
    },
    # specify where the model files are
    "checkpoint": {
        "checkpoints": glob.glob(os.path.join(model_path, "**", "*" + ".safetensors"), recursive=False),
        "type": "DS_MODEL",
        "version": 1.0
    }
}
print(f"deepspeed inference config: {ds_inference_config}")

ds_engine = deepspeed.init_inference(model, config=ds_inference_config)
model = ds_engine.module
model.eval()

  2. What packages are required and their versions:
transformers==4.51.3
accelerate==1.6.0
deepspeed==0.16.7
flash-attn==2.7.3 (this is optional since I ran with sdpa)
torch==2.6.0+cu126
  3. How to run the script:
     Step 1: Download meta-llama/Llama-4-Scout-17B-16E-Instruct to the local directory "model_to_evaluate".
     Step 2: Put the above code in a file named "llama4_dsi.py".
     Step 3: Run: accelerate launch llama4_dsi.py

Expected behavior
The model should load successfully and be distributed across the 8 GPUs.

— songdezhao, May 10, 2025

After reading your example and the DeepSpeed code, I think the main reason is the "checkpoint" config you specify:

"checkpoints": glob.glob(os.path.join(model_path, "**", "*" + ".safetensors"), recursive=False),

The root cause is this line of code:

checkpoint = checkpoint_dict["checkpoints"]

Here checkpoint_dict is not None, but checkpoint_dict["checkpoints"] is empty (your glob matched nothing), so replaced_module never gets initialized: https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/module_inject/replace_module.py#L384

As a result, this example throws:

[rank0]: UnboundLocalError: cannot access local variable 'replaced_module' where it is not associated with a value

Could you add replaced_module = None before L384, and then change the condition to:

if replaced_module is not None and 'Yuan' in str(replaced_module):

https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/module_inject/replace_module.py#L397
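
To illustrate, here is a standalone sketch of the failure mode and the proposed guard (simplified, not the actual DeepSpeed source; load_shard and handle_yuan_model are stand-ins for the real logic):

def load_shard(path):
    # stand-in for the real per-shard replacement logic
    return f"module_from_{path}"

def handle_yuan_model(module):
    # stand-in for the Yuan-specific handling
    pass

def replace_transformer_layer_sketch(checkpoint_dict):
    replaced_module = None  # proposed fix: initialize before the conditional path
    checkpoints = checkpoint_dict["checkpoints"] if checkpoint_dict is not None else None
    if checkpoints:
        for checkpoint_file in checkpoints:
            # without the initialization above, replaced_module is only ever
            # assigned here, so an empty checkpoint list leaves it unbound
            replaced_module = load_shard(checkpoint_file)
    # proposed fix: guard the check so a missing module is skipped instead of raising
    if replaced_module is not None and 'Yuan' in str(replaced_module):
        handle_yuan_model(replaced_module)
    return replaced_module

replace_transformer_layer_sketch({"checkpoints": []})  # returns None instead of raising UnboundLocalError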

— ranzhejiang, May 12, 2025

There was an issue in my original code. With the fix below, ds_inference_config["checkpoint"]["checkpoints"] now contains the list of safetensors files:

# recursive should be True
"checkpoints": glob.glob(os.path.join(model_path, "**", "*" + ".safetensors"), recursive=True),

However, I am now getting the error below, which suggests that AutoTP does not support Llama-4 models. Could AutoTP support be added for Llama-4?

AutoTP:  [(<class 'transformers.models.llama4.modeling_llama4.Llama4TextDecoderLayer'>, ['self_attn.o_proj', 'shared_expert.down_proj'])]
Loading 50 checkpoint shards:   0%|          | 0/50 [00:00<?, ?it/s]
model device: meta
deepspeed inference config: {'dtype': torch.float16, 'replace_with_kernel_inject': False, 'tensor_parallel': {'tp_size': 8}, 'checkpoint': {'checkpoints': ['./model_to_evaluate/model-00045-of-00050.safetensors', './model_to_evaluate/model-00034-of-00050.safetensors', './model_to_evaluate/model-00005-of-00050.safetensors', './model_to_evaluate/model-00032-of-00050.safetensors', './model_to_evaluate/model-00024-of-00050.safetensors', './model_to_evaluate/model-00040-of-00050.safetensors', './model_to_evaluate/model-00007-of-00050.safetensors', './model_to_evaluate/model-00038-of-00050.safetensors', './model_to_evaluate/model-00046-of-00050.safetensors', './model_to_evaluate/model-00050-of-00050.safetensors', './model_to_evaluate/model-00009-of-00050.safetensors', './model_to_evaluate/model-00033-of-00050.safetensors', './model_to_evaluate/model-00003-of-00050.safetensors', './model_to_evaluate/model-00025-of-00050.safetensors', './model_to_evaluate/model-00031-of-00050.safetensors', './model_to_evaluate/model-00006-of-00050.safetensors', './model_to_evaluate/model-00047-of-00050.safetensors', './model_to_evaluate/model-00011-of-00050.safetensors', './model_to_evaluate/model-00049-of-00050.safetensors', './model_to_evaluate/model-00041-of-00050.safetensors', './model_to_evaluate/model-00039-of-00050.safetensors', './model_to_evaluate/model-00004-of-00050.safetensors', './model_to_evaluate/model-00008-of-00050.safetensors', './model_to_evaluate/model-00014-of-00050.safetensors', './model_to_evaluate/model-00017-of-00050.safetensors', './model_to_evaluate/model-00023-of-00050.safetensors', './model_to_evaluate/model-00027-of-00050.safetensors', './model_to_evaluate/model-00013-of-00050.safetensors', './model_to_evaluate/model-00018-of-00050.safetensors', './model_to_evaluate/model-00042-of-00050.safetensors', './model_to_evaluate/model-00048-of-00050.safetensors', './model_to_evaluate/model-00001-of-00050.safetensors', './model_to_evaluate/model-00029-of-00050.safetensors', './model_to_evaluate/model-00036-of-00050.safetensors', './model_to_evaluate/model-00012-of-00050.safetensors', './model_to_evaluate/model-00021-of-00050.safetensors', './model_to_evaluate/model-00020-of-00050.safetensors', './model_to_evaluate/model-00019-of-00050.safetensors', './model_to_evaluate/model-00043-of-00050.safetensors', './model_to_evaluate/model-00026-of-00050.safetensors', './model_to_evaluate/model-00022-of-00050.safetensors', './model_to_evaluate/model-00030-of-00050.safetensors', './model_to_evaluate/model-00015-of-00050.safetensors', './model_to_evaluate/model-00044-of-00050.safetensors', './model_to_evaluate/model-00002-of-00050.safetensors', './model_to_evaluate/model-00037-of-00050.safetensors', './model_to_evaluate/model-00016-of-00050.safetensors', './model_to_evaluate/model-00035-of-00050.safetensors', './model_to_evaluate/model-00028-of-00050.safetensors', './model_to_evaluate/model-00010-of-00050.safetensors'], 'type': 'DS_MODEL', 'version': 1.0}}
[rank0]: Traceback (most recent call last):
[rank0]:     ds_engine = deepspeed.init_inference(model, config=ds_inference_config)
[rank0]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/site-packages/deepspeed/__init__.py", line 364, in init_inference
[rank0]:     engine = InferenceEngine(model, config=ds_inference_config)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/site-packages/deepspeed/inference/engine.py", line 164, in __init__
[rank0]:     self._apply_injection_policy(config, client_module)
[rank0]:   File "/usr/local/lib/python3.12/site-packages/deepspeed/inference/engine.py", line 388, in _apply_injection_policy
[rank0]:     replace_transformer_layer(client_module, self.module, checkpoint, config, self.config)
[rank0]:   File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/replace_module.py", line 388, in replace_transformer_layer
[rank0]:     replaced_module = replace_module(model=model,
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/replace_module.py", line 653, in replace_module
[rank0]:     replaced_module, _ = _replace_module(model, policy, state_dict=sd)
[rank0]:                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/replace_module.py", line 713, in _replace_module
[rank0]:     _, layer_id = _replace_module(child,
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/replace_module.py", line 713, in _replace_module
[rank0]:     _, layer_id = _replace_module(child,
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/replace_module.py", line 689, in _replace_module
[rank0]:     replaced_module = policies[child.__class__][0](child,
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/replace_module.py", line 333, in replace_fn
[rank0]:     new_module = replace_wo_policy(child, _policy, prefix=prefix, state_dict=state_dict)
[rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/replace_module.py", line 316, in replace_wo_policy
[rank0]:     return _autotp._replace_module(module)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/auto_tp.py", line 481, in _replace_module
[rank0]:     self._replace_module(child, name, class_name)
[rank0]:   File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/auto_tp.py", line 460, in _replace_module
[rank0]:     Loading.load(child, self.state_dict, checking_key, self.mp_group)
[rank0]:   File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/auto_tp.py", line 166, in load
[rank0]:     module.weight = mp_replace.copy(module.weight.data, state_dict[prefix + 'weight'])
[rank0]:                                                         ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
[rank0]: KeyError: 'model.layers.43.feed_forward.router.weight'
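
For reference, whether the missing key exists in the shards at all can be checked directly with the safetensors API (a quick diagnostic sketch, assuming the same local model directory as above):

import glob
from safetensors import safe_open

missing = "model.layers.43.feed_forward.router.weight"
for path in glob.glob("./model_to_evaluate/*.safetensors"):
    with safe_open(path, framework="pt") as f:
        if missing in f.keys():
            print(f"found {missing} in {path}")
            break
else:
    print(f"{missing} not present in any shard")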

— songdezhao, May 12, 2025

@inkcherry can you add AutoTP support for Llama-4, for both inference and training?

— ranzhejiang, May 12, 2025

FYI @delock @Yejing-Lai

— inkcherry, May 13, 2025

Hi @ranzhejiang, Llama-4 has not been supported by AutoTP yet. From the error message there seems to be a key mismatch. @songdezhao do you have a dump of the module structure?
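
Something like the following, run on the meta-device model from the repro script, would be one way to produce such a dump (just a suggestion, not a required format):

# print the qualified name and class of every submodule
for name, module in model.named_modules():
    print(name, type(module).__name__)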

— delock, May 14, 2025