[BUG] Deepspeed-Inference: support AutoTP for Llama-4 models
Describe the bug
I was trying to run DeepSpeed-Inference on Llama-4-Scout-Instruct for text generation. The process failed as soon as it started to load the model.
I am running on a single node with 8 GPUs, each with 80 GB of memory (8 × 80 GB total). I used float16.
Here is the error message:
AutoTP: [(<class 'transformers.models.llama4.modeling_llama4.Llama4TextDecoderLayer'>, ['shared_expert.down_proj', 'self_attn.o_proj'])]
Loading 0 checkpoint shards: 0it [00:00, ?it/s][rank0]: Traceback (most recent call last):
[rank0]: ds_engine = deepspeed.init_inference(model, config=ds_inference_config)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/__init__.py", line 364, in init_inference
[rank0]: engine = InferenceEngine(model, config=ds_inference_config)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/inference/engine.py", line 164, in __init__
[rank0]: self._apply_injection_policy(config, client_module)
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/inference/engine.py", line 388, in _apply_injection_policy
[rank0]: replace_transformer_layer(client_module, self.module, checkpoint, config, self.config)
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/replace_module.py", line 397, in replace_transformer_layer
[rank0]: if 'Yuan' in str(replaced_module):
[rank0]: ^^^^^^^^^^^^^^^
[rank0]: UnboundLocalError: cannot access local variable 'replaced_module' where it is not associated with a value
To Reproduce
Steps to reproduce the behavior:
- Simple inference script to reproduce
Here is a code snippet that reproduces the error:
import glob
import os
import deepspeed
import torch
from transformers import Llama4TextConfig, AutoModelForCausalLM
model_path = "./model_to_evaluate"
kwargs = {"torch_dtype": torch.float16, "attn_implementation": "sdpa"}
print(f"kwargs: {kwargs}")
model_config = Llama4TextConfig.from_pretrained(model_path, **kwargs)
# load the model on the meta device (no real weights are allocated yet)
with deepspeed.OnDevice(dtype=kwargs["torch_dtype"], device="meta", enabled=True):
    model = AutoModelForCausalLM.from_config(model_config, **kwargs)
print(f"model device: {next(model.parameters()).device}")
# set up deepspeed inference config
ds_inference_config = {
    "dtype": kwargs["torch_dtype"],
    # meta device is not compatible with kernel injection
    "replace_with_kernel_inject": False,
    # tp size equals the global number of GPUs
    "tensor_parallel": {
        "tp_size": int(os.getenv("WORLD_SIZE", "1"))
    },
    # specify where the model files are
    "checkpoint": {
        "checkpoints": glob.glob(os.path.join(model_path, "**", "*" + ".safetensors"), recursive=False),
        "type": "DS_MODEL",
        "version": 1.0
    }
}
print(f"deepspeed inference config: {ds_inference_config}")
ds_engine = deepspeed.init_inference(model, config=ds_inference_config)
model = ds_engine.module
model.eval()
- What packages are required and their versions
transformers==4.51.3
accelerate==1.6.0
deepspeed==0.16.7
flash-attn==2.7.3 (this is optional since I ran with sdpa)
torch==2.6.0+cu126
- How to run the script
Step 1: Download meta-llama/Llama-4-Scout-17B-16E-Instruct to the local directory "model_to_evaluate".
Step 2: Save the code above as "llama4_dsi.py".
Step 3: Run: accelerate launch llama4_dsi.py
Expected behavior
The model should load successfully and be distributed across the 8 GPUs.
After reading your example and the DeepSpeed code, I think the main reason is the "checkpoints" entry you pass in the "checkpoint" config:
"checkpoints": glob.glob(os.path.join(model_path, "**", "*" + ".safetensors"), recursive=False),
The root cause is this line of code:
checkpoint = checkpoint_dict["checkpoints"]
checkpoint_dict is not None, but checkpoint_dict["checkpoints"] is empty (the log shows "Loading 0 checkpoint shards"), so replaced_module is never assigned. https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/module_inject/replace_module.py#L384
As a result, this example throws the error:
[rank0]: UnboundLocalError: cannot access local variable 'replaced_module' where it is not associated with a value
Can you add replaced_module = None before L384, and then guard the check like this:
if replaced_module is not None and 'Yuan' in str(replaced_module):
https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/module_inject/replace_module.py#L397
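To illustrate the failure mode and the proposed guard outside of DeepSpeed, here is a minimal sketch. The functions `uses_yuan_buggy` and `uses_yuan_fixed` are hypothetical stand-ins, not DeepSpeed code; they only assume that the variable is bound inside a loop over the checkpoint list, which is consistent with the "Loading 0 checkpoint shards" log line.

```python
def uses_yuan_buggy(checkpoints):
    # Hypothetical stand-in: the name is bound only if the loop body runs.
    for ckpt in checkpoints:
        replaced_module = f"module loaded from {ckpt}"
    # With an empty list, replaced_module was never assigned.
    return 'Yuan' in str(replaced_module)


def uses_yuan_fixed(checkpoints):
    # Proposed guard: bind the name up front and check it before use.
    replaced_module = None
    for ckpt in checkpoints:
        replaced_module = f"module loaded from {ckpt}"
    return replaced_module is not None and 'Yuan' in str(replaced_module)


try:
    uses_yuan_buggy([])
except UnboundLocalError as err:
    print(f"buggy version failed: {err}")

print(f"fixed version with no shards: {uses_yuan_fixed([])}")
```

The guard only avoids the crash. An empty checkpoint list would still mean no weights get loaded, so a clearer error for that case might be worth considering as well.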
There was an issue in my original code. With the following change, ds_inference_config["checkpoint"]["checkpoints"] now contains the list of safetensors files:
# recursive should be True
"checkpoints": glob.glob(os.path.join(model_path, "**", "*" + ".safetensors"), recursive=True),
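For anyone hitting the same thing, the effect of the recursive flag can be checked in isolation. This is a small sketch using a temporary directory with two empty placeholder files; the helper name `find_shards` and the file names are made up for the demonstration. With recursive=False, "**" behaves like a single "*", so the pattern only matches files exactly one directory below model_path. With recursive=True, "**" also matches zero directories, so files directly inside model_path are found.

```python
import glob
import os
import tempfile


def find_shards(model_path, recursive):
    # Same pattern as in the reproduction script.
    pattern = os.path.join(model_path, "**", "*" + ".safetensors")
    return sorted(glob.glob(pattern, recursive=recursive))


with tempfile.TemporaryDirectory() as model_path:
    # Placeholder shard files located directly in the model directory.
    for name in ("model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"):
        open(os.path.join(model_path, name), "w").close()

    flat = find_shards(model_path, recursive=False)
    deep = find_shards(model_path, recursive=True)

print(f"recursive=False found {len(flat)} files")
print(f"recursive=True found {len(deep)} files")
```

If the shards always sit directly in the model directory, dropping the "**" component and globbing for "*.safetensors" would also work.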
However, I am now getting the error below, and it looks like AutoTP is not supported for Llama-4 models. Could we add AutoTP support for Llama-4?
AutoTP: [(<class 'transformers.models.llama4.modeling_llama4.Llama4TextDecoderLayer'>, ['self_attn.o_proj', 'shared_expert.down_proj'])]
Loading 50 checkpoint shards: 0%| | 0/50 [00:00<?, ?it/s]
model device: meta
deepspeed inference config: {'dtype': torch.float16, 'replace_with_kernel_inject': False, 'tensor_parallel': {'tp_size': 8}, 'checkpoint': {'checkpoints': ['./model_to_evaluate/model-00045-of-00050.safetensors', './model_to_evaluate/model-00034-of-00050.safetensors', './model_to_evaluate/model-00005-of-00050.safetensors', './model_to_evaluate/model-00032-of-00050.safetensors', './model_to_evaluate/model-00024-of-00050.safetensors', './model_to_evaluate/model-00040-of-00050.safetensors', './model_to_evaluate/model-00007-of-00050.safetensors', './model_to_evaluate/model-00038-of-00050.safetensors', './model_to_evaluate/model-00046-of-00050.safetensors', './model_to_evaluate/model-00050-of-00050.safetensors', './model_to_evaluate/model-00009-of-00050.safetensors', './model_to_evaluate/model-00033-of-00050.safetensors', './model_to_evaluate/model-00003-of-00050.safetensors', './model_to_evaluate/model-00025-of-00050.safetensors', './model_to_evaluate/model-00031-of-00050.safetensors', './model_to_evaluate/model-00006-of-00050.safetensors', './model_to_evaluate/model-00047-of-00050.safetensors', './model_to_evaluate/model-00011-of-00050.safetensors', './model_to_evaluate/model-00049-of-00050.safetensors', './model_to_evaluate/model-00041-of-00050.safetensors', './model_to_evaluate/model-00039-of-00050.safetensors', './model_to_evaluate/model-00004-of-00050.safetensors', './model_to_evaluate/model-00008-of-00050.safetensors', './model_to_evaluate/model-00014-of-00050.safetensors', './model_to_evaluate/model-00017-of-00050.safetensors', './model_to_evaluate/model-00023-of-00050.safetensors', './model_to_evaluate/model-00027-of-00050.safetensors', './model_to_evaluate/model-00013-of-00050.safetensors', './model_to_evaluate/model-00018-of-00050.safetensors', './model_to_evaluate/model-00042-of-00050.safetensors', './model_to_evaluate/model-00048-of-00050.safetensors', './model_to_evaluate/model-00001-of-00050.safetensors', 
'./model_to_evaluate/model-00029-of-00050.safetensors', './model_to_evaluate/model-00036-of-00050.safetensors', './model_to_evaluate/model-00012-of-00050.safetensors', './model_to_evaluate/model-00021-of-00050.safetensors', './model_to_evaluate/model-00020-of-00050.safetensors', './model_to_evaluate/model-00019-of-00050.safetensors', './model_to_evaluate/model-00043-of-00050.safetensors', './model_to_evaluate/model-00026-of-00050.safetensors', './model_to_evaluate/model-00022-of-00050.safetensors', './model_to_evaluate/model-00030-of-00050.safetensors', './model_to_evaluate/model-00015-of-00050.safetensors', './model_to_evaluate/model-00044-of-00050.safetensors', './model_to_evaluate/model-00002-of-00050.safetensors', './model_to_evaluate/model-00037-of-00050.safetensors', './model_to_evaluate/model-00016-of-00050.safetensors', './model_to_evaluate/model-00035-of-00050.safetensors', './model_to_evaluate/model-00028-of-00050.safetensors', './model_to_evaluate/model-00010-of-00050.safetensors'], 'type': 'DS_MODEL', 'version': 1.0}}
[rank0]: Traceback (most recent call last):
[rank0]: ds_engine = deepspeed.init_inference(model, config=ds_inference_config)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/__init__.py", line 364, in init_inference
[rank0]: engine = InferenceEngine(model, config=ds_inference_config)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/inference/engine.py", line 164, in __init__
[rank0]: self._apply_injection_policy(config, client_module)
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/inference/engine.py", line 388, in _apply_injection_policy
[rank0]: replace_transformer_layer(client_module, self.module, checkpoint, config, self.config)
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/replace_module.py", line 388, in replace_transformer_layer
[rank0]: replaced_module = replace_module(model=model,
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/replace_module.py", line 653, in replace_module
[rank0]: replaced_module, _ = _replace_module(model, policy, state_dict=sd)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/replace_module.py", line 713, in _replace_module
[rank0]: _, layer_id = _replace_module(child,
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/replace_module.py", line 713, in _replace_module
[rank0]: _, layer_id = _replace_module(child,
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/replace_module.py", line 689, in _replace_module
[rank0]: replaced_module = policies[child.__class__][0](child,
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/replace_module.py", line 333, in replace_fn
[rank0]: new_module = replace_wo_policy(child, _policy, prefix=prefix, state_dict=state_dict)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/replace_module.py", line 316, in replace_wo_policy
[rank0]: return _autotp._replace_module(module)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/auto_tp.py", line 481, in _replace_module
[rank0]: self._replace_module(child, name, class_name)
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/auto_tp.py", line 460, in _replace_module
[rank0]: Loading.load(child, self.state_dict, checking_key, self.mp_group)
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/auto_tp.py", line 166, in load
[rank0]: module.weight = mp_replace.copy(module.weight.data, state_dict[prefix + 'weight'])
[rank0]: ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
[rank0]: KeyError: 'model.layers.43.feed_forward.router.weight'
@inkcherry Can you add AutoTP support for Llama-4, for both inference and training?
FYI @delock @Yejing-Lai
Hi @ranzhejiang, Llama-4 is not yet supported by AutoTP. From the error message, there seems to be a key mismatch. @songdezhao Do you have a dump of the module structure?
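In addition to the module dump, it may help to see which key names the checkpoint itself uses. The sketch below assumes the model directory follows the usual Hugging Face sharded layout with a model.safetensors.index.json file containing a "weight_map"; the helpers `load_weight_map` and `keys_with_suffix` are hypothetical names for this example. It searches for keys by suffix, so a differing prefix would show up in the output.

```python
import json
import os


def load_weight_map(model_path, index_name="model.safetensors.index.json"):
    # The index maps each parameter name to the shard file that stores it.
    with open(os.path.join(model_path, index_name)) as f:
        return json.load(f)["weight_map"]


def keys_with_suffix(weight_map, suffix):
    # Match on the suffix so keys with an unexpected prefix are still listed.
    return sorted(key for key in weight_map if key.endswith(suffix))


model_path = "./model_to_evaluate"
index_file = os.path.join(model_path, "model.safetensors.index.json")
if os.path.exists(index_file):
    weight_map = load_weight_map(model_path)
    missing_key = "model.layers.43.feed_forward.router.weight"
    print(f"exact key present: {missing_key in weight_map}")
    for key in keys_with_suffix(weight_map, "layers.43.feed_forward.router.weight"):
        print(f"{key} -> {weight_map[key]}")
```

Comparing this output with the names from model.named_parameters() should show whether the KeyError comes from a prefix difference or from a parameter that is missing altogether.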