
Add support for torch.compile dynamic shapes

Open · warner-benjamin opened this issue 9 months ago · 3 comments

This PR adds support for compiling models with dynamic shapes (`dynamic=True`) to almost all models with SDPA attention implementations that currently do not support them. #30442 added support for Llama, Gemma, OLMo, and Cohere.
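For context, a minimal sketch of what compiling with dynamic shapes means; the module and tensor sizes here are illustrative, not taken from the PR (the eager backend is used so the sketch runs without a compiler toolchain):

```python
import torch
import torch.nn.functional as F

class TinySDPA(torch.nn.Module):
    """Toy stand-in for a model's SDPA attention call."""
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

model = TinySDPA()
# dynamic=True asks dynamo to treat input dims as symbolic, so varying
# sequence lengths reuse one graph instead of triggering recompilation.
compiled = torch.compile(model, dynamic=True, backend="eager")

for seq_len in (8, 16):
    q = torch.randn(1, 2, seq_len, 4)  # (batch, heads, seq, head_dim)
    out = compiled(q, q, q)
    assert out.shape == (1, 2, seq_len, 4)
```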

The only model not modified is DBRX, which needs the changes from both #30070 and #30442 to gain support for SDPA's Flash Attention kernel and for dynamic shapes, as I believe it suffers from the same training memory issues detailed in #30010.

As mentioned in #30442, moving the `is_causal` dispatch logic from an inline expression to an `if` statement is required to support both `fullgraph=True` and `dynamic=True`.

I kept the `qlen>1` comments but could remove them to match Llama, which doesn't have them.
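A hedged sketch of the dispatch pattern described above; the function name and exact condition are illustrative, not the actual transformers code:

```python
import torch
import torch.nn.functional as F

def sdpa_attention(query, key, value, causal_mask=None):
    # Inline form (roughly: is_causal=causal_mask is None and q_len > 1
    # passed directly in the SDPA call) can break graph capture under
    # fullgraph=True with dynamic=True; an explicit if statement works.
    q_len = query.shape[2]
    if causal_mask is None and q_len > 1:
        is_causal = True
    else:
        is_causal = False
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=causal_mask, is_causal=is_causal
    )

q = torch.randn(1, 2, 8, 4)  # (batch, heads, seq, head_dim)
out = sdpa_attention(q, q, q)
assert out.shape == q.shape
```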

cc @ArthurZucker and @fxmarty

warner-benjamin avatar Apr 29 '24 23:04 warner-benjamin

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@fxmarty @ArthurZucker I didn't touch Llava. The current "Tensor-likes are not close" error shouldn't have anything to do with this PR. It should be ready to go from my end.

warner-benjamin avatar May 07 '24 17:05 warner-benjamin

no worries, rebasing on main should most probably fix this!

ArthurZucker avatar May 10 '24 08:05 ArthurZucker