lightning-thunder
Add `torch.nn.functional.scaled_grouped_mm`
What does this PR do?
As per the title, this PR adds support for `torch.nn.functional.scaled_grouped_mm`.
ref: https://docs.pytorch.org/docs/main/generated/torch.nn.functional.scaled_grouped_mm.html
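The following pure-Python sketch illustrates the arithmetic a scaled grouped matmul performs in one common case. It assumes the 2D x 3D layout used by PyTorch's grouped GEMMs and row-wise scaling. In that layout, `mat_a` stacks the rows of all G groups and has shape (total_M, K). `mat_b` holds one K x N matrix per group. `offs` gives the cumulative end row of each group. There is one scale per row of `mat_a` and one per (group, output column) of `mat_b`. The function name `scaled_grouped_mm_reference` and its argument layout are hypothetical. They are not the signature of `torch.nn.functional.scaled_grouped_mm` or of Thunder's implementation, which the linked docs describe. Because these scales form a rank-1 factor, applying them after each dot product gives the same result as dequantizing the inputs first.

```python
from typing import List

Matrix = List[List[float]]


def scaled_grouped_mm_reference(
    mat_a: Matrix,               # (total_M, K): rows of all groups stacked
    mat_b: List[Matrix],         # (G, K, N): one K x N matrix per group
    scale_a: List[float],        # (total_M,): assumed row-wise scale for mat_a
    scale_b: List[List[float]],  # (G, N): assumed per-group column-wise scale for mat_b
    offs: List[int],             # (G,): cumulative end row of each group in mat_a
) -> Matrix:
    """Hypothetical reference for a row-wise scaled grouped matmul (2D x 3D case)."""
    if len(offs) != len(mat_b) or len(scale_b) != len(mat_b):
        raise ValueError("offs, mat_b and scale_b need one entry per group")
    if offs and offs[-1] != len(mat_a):
        raise ValueError("the last offset must equal the number of rows in mat_a")

    out: Matrix = []
    start = 0
    for g, end in enumerate(offs):
        b = mat_b[g]
        n_cols = len(b[0])
        # Rows start..end-1 of mat_a belong to group g and multiply mat_b[g].
        for i in range(start, end):
            row = mat_a[i]
            out.append([
                # Plain dot product, then the rank-1 dequantization scale.
                scale_a[i] * scale_b[g][n] * sum(row[k] * b[k][n] for k in range(len(row)))
                for n in range(n_cols)
            ])
        start = end
    return out


# Usage: three rows of A split into two groups (row 0, then rows 1-2).
a = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
b = [
    [[1.0, 0.0], [0.0, 1.0]],  # group 0: identity
    [[2.0, 0.0], [0.0, 2.0]],  # group 1: 2 * identity
]
scale_a = [1.0, 0.5, 1.0]
scale_b = [[1.0, 1.0], [1.0, 0.5]]
offs = [1, 3]

result = scaled_grouped_mm_reference(a, b, scale_a, scale_b, offs)
print(result)  # [[1.0, 2.0], [3.0, 2.0], [10.0, 6.0]]
```

Real kernels take low-precision inputs (for example FP8) and run on GPU. This sketch uses plain floats only to make the grouping and scaling explicit.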