
[PyTorch] Fix minor bug in computing num_gqa_groups_per_partition

Open knowlsie opened this issue 1 year ago • 3 comments

Currently, the number of GQA groups per partition for DotProductAttention is computed from a default value rather than the actual value computed earlier in the initializer, which causes errors when tensor parallelism is enabled. This just fixes that.
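To make the bug pattern concrete, here is a minimal, hypothetical sketch. It is not TransformerEngine's code, and the exact lines touched by the patch are not shown in this thread. It assumes the stale default is the tensor-parallel size argument: the initializer first resolves the real TP size (here a hypothetical `tp_group_world_size` stands in for querying the process group), but the per-partition GQA count then divides by the raw argument, whose default is 1. With tensor parallelism enabled, every rank ends up believing it owns all GQA groups.

```python
from typing import Optional


class AttentionPartitioning:
    """Hypothetical stand-in for an attention module's initializer (not TE's actual code)."""

    def __init__(
        self,
        num_attention_heads: int,
        num_gqa_groups: Optional[int] = None,
        tp_size: int = 1,
        tp_group_world_size: Optional[int] = None,  # hypothetical: world size of a TP group
        buggy: bool = False,
    ):
        # Resolve the tensor-parallel size: a process group, if given, overrides the argument.
        self.tp_size = tp_size if tp_group_world_size is None else tp_group_world_size
        # No GQA group count means plain multi-head attention (one group per head).
        self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
        if buggy:
            # Bug pattern: divides by the raw argument (default 1), ignoring the resolved value.
            self.num_gqa_groups_per_partition = self.num_gqa_groups // tp_size
        else:
            # Fix: use the value resolved earlier in the initializer.
            self.num_gqa_groups_per_partition = self.num_gqa_groups // self.tp_size


if __name__ == "__main__":
    fixed = AttentionPartitioning(32, num_gqa_groups=8, tp_group_world_size=4)
    stale = AttentionPartitioning(32, num_gqa_groups=8, tp_group_world_size=4, buggy=True)
    print(fixed.num_gqa_groups_per_partition)  # 2 (8 groups split across 4 ranks)
    print(stale.num_gqa_groups_per_partition)  # 8 (wrong: every rank claims all groups)
```

The mismatch only surfaces when the resolved TP size differs from the argument's default, which is why single-GPU runs would not catch it.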

knowlsie, Apr 13 '24

Thanks for spotting this! @knowlsie

ksivaman, Apr 15 '24

/te-ci pytorch

ksivaman, Apr 15 '24

Thanks for the quick review @ksivaman. I can't merge without write access, so feel free to merge whenever.

knowlsie, Apr 15 '24

Fixed in #1044

ksivaman, Jul 26 '24