
Efficient implementation of all_reduce and all_gather for collect_lora_a

Open • hayleyhu opened this issue 1 year ago • 2 comments

Feature request

We noticed that collect_lora_a() calls all_gather and all_reduce every time it is invoked.

Could you provide a more efficient implementation of this soon? If not, could you suggest how we might implement it ourselves?

Thanks.

Motivation

We observe significant time spent in all_reduce and all_gather during LoRA adapter inference. With an adapter, 60% of first-token communication time is spent in all_gather operations, versus only 0.01% without adapters (and communication accounts for 55% of GPU operation time).

As a result, first-token latency noticeably affects the user experience.

Your contribution

We can do the implementation if the advice is clear.

hayleyhu • Mar 26 '24 23:03

Hey @hayleyhu,

Thanks for submitting this issue. To clarify: are you running with the SGMV kernel disabled? (The latest version of LoRAX should log whether it is enabled during initialization.)

When using SGMV, we only run the all_gather and all_reduce once per layer, which is optimal. It is true that when SGMV is disabled we do it once per adapter, which is much slower, but this is primarily a fallback code path that most users shouldn't be encountering.
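To illustrate the "once per layer" versus "once per adapter" difference, here is a minimal, self-contained sketch that simulates tensor-parallel ranks in a single process with NumPy. Everything in it is hypothetical: `SimulatedGroup`, `lora_a_per_adapter`, and `lora_a_batched` are not LoRAX or torch.distributed APIs. It assumes each adapter's LoRA A matrix is column-sharded across ranks, that all adapters share the same LoRA rank `r`, and that each adapter's tokens occupy one contiguous segment of the batch. It only counts all_gather calls and does not model latency or message size.

```python
import numpy as np


class SimulatedGroup:
    """Hypothetical in-process stand-in for a tensor-parallel process group.

    It counts collectives only; it is neither torch.distributed nor LoRAX code.
    """

    def __init__(self, world_size):
        self.world_size = world_size
        self.num_all_gathers = 0

    def all_gather(self, per_rank):
        # Each rank contributes a (tokens, r // world_size) shard; concatenating
        # along the last axis rebuilds the full (tokens, r) LoRA-A output.
        assert len(per_rank) == self.world_size
        self.num_all_gathers += 1
        return np.concatenate(per_rank, axis=-1)


def lora_a_per_adapter(group, x, segments, a_shards):
    """Fallback-style path: one all_gather per adapter segment."""
    r = sum(s.shape[1] for s in next(iter(a_shards.values())))
    out = np.zeros((x.shape[0], r))
    for adapter, (start, end) in segments.items():
        local = [x[start:end] @ shard for shard in a_shards[adapter]]
        out[start:end] = group.all_gather(local)
    return out


def lora_a_batched(group, x, segments, a_shards):
    """Batched (SGMV-like) path: fill each rank's local output for every
    segment first, then issue a single all_gather for the whole layer."""
    local = []
    for rank in range(group.world_size):
        r_local = next(iter(a_shards.values()))[rank].shape[1]
        v = np.zeros((x.shape[0], r_local))
        for adapter, (start, end) in segments.items():
            v[start:end] = x[start:end] @ a_shards[adapter][rank]
        local.append(v)
    return group.all_gather(local)


# Usage: three adapters, two simulated ranks, LoRA rank 4.
rng = np.random.default_rng(0)
world_size, hidden, r = 2, 8, 4
x = rng.standard_normal((6, hidden))
segments = {"adapter_0": (0, 2), "adapter_1": (2, 4), "adapter_2": (4, 6)}
a_shards = {
    name: np.split(rng.standard_normal((hidden, r)), world_size, axis=1)
    for name in segments
}

g_slow, g_fast = SimulatedGroup(world_size), SimulatedGroup(world_size)
slow = lora_a_per_adapter(g_slow, x, segments, a_shards)
fast = lora_a_batched(g_fast, x, segments, a_shards)
print(g_slow.num_all_gathers, g_fast.num_all_gathers)  # 3 1
print(np.allclose(slow, fast))  # True
```

In this toy setup both paths produce the same output, but the per-adapter loop issues one all_gather per adapter while the batched path issues a single one for the layer, regardless of how many adapters are in the batch.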

tgaddair • Mar 28 '24 21:03

SGMV is enabled, but the all_gather and all_reduce time is still very high.

hayleyhu • Mar 29 '24 23:03