
DTensor issues when running Llama4ForConditionalGeneration with tensor parallel.

czkkkkkk opened this issue 4 months ago · 3 comments

System Info

transformers version: 4.52.4
pytorch version: 2.6

Who can help?


When running Llama4 with tensor parallelism, the torch.nn.Unfold module used in the Llama4 vision model is not compatible with DTensor, so I get this error: NotImplementedError: Operator aten.im2col.default does not have a sharding strategy registered. It looks like the latest transformers wraps the inputs of layers without a tp_plan entry as replicated DTensors, and Unfold (which lowers to aten.im2col) has no DTensor sharding rule.
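
For context, the incompatibility reproduces in isolation. Below is a minimal sketch, assuming a single-process gloo group on CPU; the mesh setup, tensor shape, and kernel size are illustrative, not taken from the model:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate

# Single-process group, just enough to build a 1-device CPU mesh.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

# A replicated DTensor, like the ones transformers feeds to layers without a tp_plan.
x = distribute_tensor(torch.randn(1, 3, 336, 336), mesh, [Replicate()])
torch.nn.Unfold(kernel_size=14, stride=14)(x)
# NotImplementedError: Operator aten.im2col.default does not have a
# sharding strategy registered.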

To work around this error, I manually converted the DTensor input back to a regular tensor around the unfold call:

from torch.distributed.tensor import DTensor

if isinstance(hidden_states, DTensor):
    # Unwrap to a plain local tensor so nn.Unfold (aten.im2col) can run,
    # then re-wrap the result with the original mesh and placements.
    device_mesh, placements = hidden_states.device_mesh, hidden_states.placements
    hidden_states = self.unfold(hidden_states.to_local())
    hidden_states = DTensor.from_local(hidden_states, device_mesh, placements)
else:
    hidden_states = self.unfold(hidden_states)
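
Note that this round-trip is only semantically safe because the activation here is replicated (every rank holds the full tensor, per the replicate-DTensor behavior described above); if it were sharded, unfolding each local shard independently would give a wrong result.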

After the change, I got AttributeError: 'BaseModelOutput' object has no attribute 'to_local' when running vision_model (https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/llama4/modeling_llama4.py#L1543), presumably because the same DTensor unwrapping gets applied to the vision model's output, which is a BaseModelOutput rather than a tensor.
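
A guard of the following shape would sidestep that second error. This is a sketch, not a tested fix: the helper name to_local_maybe is hypothetical, and exactly where in modeling_llama4.py it should be applied is an assumption; BaseModelOutput and its last_hidden_state field are standard transformers API:

from torch.distributed.tensor import DTensor
from transformers.modeling_outputs import BaseModelOutput

def to_local_maybe(value):
    # Hypothetical helper: convert a DTensor back to a plain tensor,
    # reaching inside a BaseModelOutput if necessary.
    if isinstance(value, BaseModelOutput):
        value.last_hidden_state = to_local_maybe(value.last_hidden_state)
        return value
    return value.to_local() if isinstance(value, DTensor) else value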

Information

  • [ ] The official example scripts
  • [x] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

Minimal script to reproduce the error:

# test.py
import torch
from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration

if __name__ == '__main__':
    model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
    model = Llama4ForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        tp_plan="auto",  # DTensor-based tensor parallelism across the torchrun ranks
    )
    B = 1    # batch size
    S = 128  # sequence length
    input_ids = torch.randint(0, 1000, (B, S))
    attention_mask = torch.ones((B, S))
    pixel_values = torch.randn((5, 3, 336, 336)).to(torch.bfloat16)
    model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)

Run with:

torchrun --nproc-per-node=8 test.py

Expected behavior

The forward pass should complete successfully.
