Add support for torch.compile dynamic shapes
This PR adds support for compiling models with dynamic shapes (`dynamic=True`) to almost all models with SDPA attention implementations that currently do not support them. #30442 added support for Llama, Gemma, OLMo, & Cohere.
The only model not modified is DBRX, which needs the changes from both #30070 and #30442 to support SDPA's Flash Attention kernel and dynamic shapes; I believe it suffers from the same training memory issues detailed in #30010.
As mentioned in #30442, moving the `is_causal` dispatch logic from an inline expression to an `if` statement is required to support both `fullgraph=True` and `dynamic=True`.
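To illustrate the pattern (a minimal sketch, not the exact diff; `sdpa_forward` and its arguments are hypothetical names): hoisting the dispatch into an explicit `if` statement, rather than passing the boolean expression inline to the SDPA call, keeps `is_causal` a plain Python bool under compilation.

```python
import torch
import torch.nn.functional as F


def sdpa_forward(q, k, v, attention_mask=None):
    q_len = q.shape[-2]
    # Dispatch hoisted out of the call site: under torch.compile with
    # dynamic=True and fullgraph=True, evaluating this as a statement
    # avoids graph breaks that the inline `attention_mask is None and
    # q_len > 1` expression can cause when q_len is a symbolic size.
    if attention_mask is None and q_len > 1:
        is_causal = True
    else:
        is_causal = False
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=attention_mask, is_causal=is_causal
    )
```

The behavior is identical to the inline form; only where the boolean is computed changes.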
I kept the `q_len > 1` comments, but could remove them to match Llama, which doesn't have them.
cc @ArthurZucker and @fxmarty
@fxmarty @ArthurZucker I didn't touch Llava. The current "Tensor-likes are not close" error shouldn't have anything to do with this PR. It should be ready to go from my end.
no worries, rebasing on main should most probably fix this!