TransformerEngine
[PyTorch] Fix minor bug in computing num_gqa_groups_per_partition
Currently, the number of GQA groups per partition in DotProductAttention is computed from a default value rather than the actual value resolved earlier in the initializer, which causes errors when tensor parallelism is enabled. This PR fixes that, as sketched below.
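For context, here is a minimal sketch of the class of bug being fixed, assuming the initializer first resolves a default for the GQA group count and then derives the per-partition count. The class body, parameter names, and the `divide` helper are illustrative only, not the actual TransformerEngine source.

```python
def divide(numerator: int, denominator: int) -> int:
    """Integer division that checks for an exact split across tensor-parallel ranks."""
    assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
    return numerator // denominator


class DotProductAttention:
    """Illustrative stand-in for the attention module's initializer logic."""

    def __init__(self, num_attention_heads: int, num_gqa_groups: int = None, tp_size: int = 1):
        # Resolve the default: fall back to plain MHA when no GQA group count is given.
        self.num_gqa_groups = (
            num_attention_heads if num_gqa_groups is None else num_gqa_groups
        )

        # Buggy pattern: deriving the per-partition count from the raw constructor
        # argument (which may still be None / the default) breaks once tp_size > 1.
        # self.num_gqa_groups_per_partition = divide(num_gqa_groups, tp_size)

        # Fixed pattern: use the value resolved just above.
        self.num_gqa_groups_per_partition = divide(self.num_gqa_groups, tp_size)
```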
Thanks for spotting this! @knowlsie
/te-ci pytorch
Thanks for the quick review, @ksivaman. I can’t merge without write access, so feel free to merge whenever.
Fixed in #1044