RuntimeError with DirectML backend (torch.cat in tokenizer streaming forward)
I’m running VibeVoice on an AMD GPU with torch-directml as the backend, and I hit a runtime error during inference:
```
\vibevoice\modular\modular_vibevoice_tokenizer.py", line 495, in _forward_streaming
    full_input = torch.cat([cached_input, x], dim=2)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The parameter is incorrect.
```
Steps to reproduce:
- Install torch-directml on Windows
- Run inference with device=dml.device(0) on the AMD GPU, where dml is torch_directml
- The error occurs while the model is generating output (minimal repro sketch below)
GPU: AMD Radeon RX 7900 GRE; OS: Windows 11 24H2
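For context, here is a minimal sketch of the failing pattern. The tensor shapes are illustrative assumptions, not the actual streaming-cache shapes; only the cat-on-DirectML pattern matters:

```python
# Minimal repro sketch; shapes below are assumed for illustration,
# not taken from VibeVoice's streaming tokenizer.
import torch
import torch_directml

device = torch_directml.device(0)  # first DirectML adapter (the AMD GPU)

# Mimic the cache-plus-new-chunk concat from _forward_streaming:
cached_input = torch.randn(1, 64, 320, device=device)  # (batch, channels, time)
x = torch.randn(1, 64, 160, device=device)

# On the affected torch-directml build this line raises
# "RuntimeError: The parameter is incorrect."
full_input = torch.cat([cached_input, x], dim=2)
print(full_input.shape)
```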
Since we don’t have access to an AMD GPU, it’s difficult for us to reproduce and resolve this directly. Thanks to @vatsalm1611 there is a candidate fix in PR #51; please try it and let us know if it works for you.
@J4R3LL could you please test PR #51 on your AMD + torch-directml setup and confirm if it fixes the concat error? Maintainers don’t have AMD hardware to reproduce.
Update: #51 fixes the DirectML concat crash
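For anyone who can't apply the PR yet, a generic stopgap for this class of DirectML concat failure (to be clear, this is not necessarily what #51 actually changes) is to route the cat through the CPU:

```python
import torch

def cat_safe(tensors, dim=0):
    """Concatenate on the CPU and move the result back. A stopgap that
    sidesteps the DirectML failure at the cost of extra transfers."""
    device = tensors[0].device
    return torch.cat([t.cpu() for t in tensors], dim=dim).to(device)

# drop-in at the failing site:
# full_input = cat_safe([cached_input, x], dim=2)
```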
Follow-up: opened #54 to address a separate torch-directml issue originating in Transformers (attention_mask dtype overflow). The fix forces a boolean attention_mask before model.generate()/forward() (sketched below). It is separate in scope and does not affect #51.
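For reference, a sketch of that coercion; batch, tokenizer, and model are illustrative Hugging Face-style names, not VibeVoice's exact API:

```python
import torch

def force_bool_attention_mask(batch: dict) -> dict:
    """Cast an integer attention_mask to bool before generate()/forward(),
    avoiding the dtype overflow torch-directml hits inside Transformers."""
    mask = batch.get("attention_mask")
    if mask is not None and mask.dtype != torch.bool:
        batch["attention_mask"] = mask.to(torch.bool)
    return batch

# usage (illustrative names):
# batch = force_bool_attention_mask(tokenizer(text, return_tensors="pt"))
# output_ids = model.generate(**batch)
```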