Updates to support new sdpa function
For my understanding, is it possible for SDPA to detect that the input key/value shapes are targeting GQA so that we do not need to pass enable_gqa=True?
In this case, the key/value shapes are not broadcastable, so there should be no ambiguity about the semantics.
@awgu We had a lot of discussion on this one with different folks around the org. TL;DR: yes, it's completely possible to recognize when users want GQA and enable it if the shapes work. However, this doesn't fit naturally into the existing broadcasting semantics, so in theory users could make a mistake, pass in misshaped inputs, and not get an error. The consensus was to add this extra check (the explicit `enable_gqa=True`) so that users give a strong signal of their intention.
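To make the trade-off concrete, here is a minimal sketch of what shape-based GQA inference could look like. `infer_gqa` is a hypothetical helper, not part of PyTorch, and it assumes the usual `(batch, num_heads, seq_len, head_dim)` layout for query, key, and value.

```python
from typing import Tuple

# Assumed layout: (batch, num_heads, seq_len, head_dim)
Shape = Tuple[int, int, int, int]


def infer_gqa(q_shape: Shape, k_shape: Shape, v_shape: Shape) -> bool:
    """Hypothetical check: True if the shapes look like grouped-query attention."""
    b_q, h_q, _, d_q = q_shape
    b_k, h_k, s_k, d_k = k_shape
    b_v, h_v, s_v, _ = v_shape
    # Key and value must agree with each other on head count and sequence length.
    if h_k != h_v or s_k != s_v:
        raise ValueError("key and value must have matching num_heads and seq_len")
    # Query and key must agree on batch size and head_dim for q @ k^T.
    if b_q != b_k or b_k != b_v or d_q != d_k:
        raise ValueError("query/key/value batch sizes or head_dim do not match")
    if h_q == h_k:
        return False  # ordinary multi-head attention, no grouping needed
    if h_q % h_k != 0:
        raise ValueError(f"{h_q} query heads are not divisible by {h_k} kv heads")
    return True  # each group of h_q // h_k query heads shares one kv head
```

One way a mistake could slip through with this kind of inference: 8 query heads with 4 KV heads passes the check even if the 4 was a bug (say, the model was meant to have 8 KV heads), so the call runs silently instead of raising. Requiring `enable_gqa=True` turns that into an explicit opt-in.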
Hey @drisspg @jainapurva, your benchmarks here are really helpful; this looks promising. I'm looking at integrating this into torchtune, since we also currently expand and reshape the key/value heads. Are there any gotchas I should watch out for, or any reasons not to use this, given that it didn't end up landing in torchtitan?