
How to use `topk`?

Open EricLBuehler opened this issue 1 year ago • 6 comments

I am trying to use topk to implement X-LoRA in Candle, and I want to perform topk along the last dimension. Specifically, I need the returned indices (as returned by torch.topk).

These indices will be used to create a mask that zeroes out all the values not in the topk, and/or to apply scalings to the nonzero values. This may be hard to understand, so please see this snippet from our X-LoRA library.
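To make the request concrete, here is a minimal CPU-side sketch of what "topk indices along the last dimension, then mask and scale" means. It is plain Rust over a row-major `[rows, cols]` `Vec<f32>`, not Candle's API; `topk_last_dim`, `apply_topk_mask`, and the single `scale` factor are hypothetical stand-ins (X-LoRA's real scalings come from its classifier).

```rust
/// Hypothetical helper: for each row of a row-major `[rows, cols]` buffer,
/// return the indices of the `k` largest values, largest first
/// (the analogue of `torch.topk(x, k, dim=-1).indices`).
fn topk_last_dim(data: &[f32], cols: usize, k: usize) -> Vec<Vec<usize>> {
    assert!(cols > 0 && data.len() % cols == 0, "data must be [rows, cols]");
    let k = k.min(cols);
    data.chunks(cols)
        .map(|row| {
            let mut idx: Vec<usize> = (0..cols).collect();
            // Sort indices by descending value; total_cmp is a total order (handles NaN).
            idx.sort_by(|&a, &b| row[b].total_cmp(&row[a]));
            idx.truncate(k);
            idx
        })
        .collect()
}

/// Hypothetical helper: zero every entry that is not in its row's top-k,
/// and multiply the kept entries by `scale`.
fn apply_topk_mask(data: &[f32], cols: usize, k: usize, scale: f32) -> Vec<f32> {
    let mut out = vec![0.0f32; data.len()];
    for (r, idx) in topk_last_dim(data, cols, k).into_iter().enumerate() {
        for c in idx {
            out[r * cols + c] = data[r * cols + c] * scale;
        }
    }
    out
}
```

For example, with two rows `[0.1, 0.7, 0.2]` and `[0.9, 0.3, 0.5]` and `k = 2`, the indices are `[1, 2]` and `[0, 2]`, and the masked output keeps only those positions.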

Is there a way to implement this with the current Candle functions, or is a dedicated function planned?


I looked at the Mixtral MoE expert-selection implementation, but I can't quite follow it:

https://github.com/huggingface/candle/blob/3144150b8d1b80b2c6b469dcab5b717598f0a458/candle-transformers/src/models/mixtral.rs#L302-L323

How does this work? Thanks!

EricLBuehler, Mar 30 '24

Hey @EricLBuehler, I'm not sure about the roadmap, and you may have already solved this, but I think I can help explain the code above. If I'm reading it correctly, it's probably not a great example of an efficient top_k implementation.

Here is roughly what I think is happening. The routing-weights tensor should be fairly small: its shape should be something like `[batch_size * seq_len, n_experts]`, where `n_experts = 8` for Mixtral. We iterate over each row, i.e. each batch/seq_len position or token (line 308), and record, for each chosen expert, which token it should process (line 312 for the idx). To pick the token's expert(s), we sort `dst` (which just holds the expert indices) by routing weight (line 310) and keep only the first `num_experts_per_tok` (lines 312 and 318).
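A rough std-only sketch of that per-token loop, written from the description above rather than copied from the linked file: `route_tokens` is a hypothetical name, the routing weights are assumed to already be on the host as one `Vec<f32>` per token, and the renormalization of the chosen weights so they sum to 1 is an assumption about Mixtral-style routers that should be checked against the linked lines.

```rust
/// Hypothetical sketch of per-token top-k expert routing.
/// `routing_weights[row]` holds the `n_experts` weights for one token
/// (assumes every row has exactly `n_experts` entries).
/// Returns, per expert: the token rows it processes and their renormalized weights.
fn route_tokens(
    routing_weights: &[Vec<f32>],
    n_experts: usize,
    experts_per_tok: usize,
) -> (Vec<Vec<u32>>, Vec<Vec<f32>>) {
    let mut top_x: Vec<Vec<u32>> = vec![Vec::new(); n_experts];
    let mut selected_rws: Vec<Vec<f32>> = vec![Vec::new(); n_experts];
    for (row_idx, rw) in routing_weights.iter().enumerate() {
        // `dst` = expert indices, fully sorted by descending routing weight.
        let mut dst: Vec<usize> = (0..rw.len()).collect();
        dst.sort_by(|&i, &j| rw[j].total_cmp(&rw[i]));
        // Keep only the first `experts_per_tok` experts for this token.
        let chosen = &dst[..experts_per_tok.min(dst.len())];
        let sum: f32 = chosen.iter().map(|&e| rw[e]).sum();
        for &e in chosen {
            top_x[e].push(row_idx as u32); // expert `e` will process this token row
            selected_rws[e].push(rw[e] / sum); // weight renormalized over the chosen experts
        }
    }
    (top_x, selected_rws)
}
```

Note that the top-k here is just "full sort, then take the first `experts_per_tok`", which is cheap when each row has only 8 entries.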

I say it's probably not an efficient top_k because each row of routing weights is fully sorted (line 310). That's fine here, since n_experts is small, but for larger vectors it's wasteful: you don't need the whole vector sorted, only the top k.
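For larger rows, one way to avoid the full sort (a swapped-in technique, not what the Mixtral code does) is partial selection with the standard library's `select_nth_unstable_by`, which partitions in expected O(n), followed by sorting only the k winners, for O(n + k log k) overall. `topk_partial` is a hypothetical name; this is a sketch on a plain `&[f32]` row.

```rust
/// Hypothetical helper: indices of the `k` largest values in `row`, largest first,
/// without fully sorting the row.
fn topk_partial(row: &[f32], k: usize) -> Vec<usize> {
    let k = k.min(row.len());
    if k == 0 {
        return Vec::new();
    }
    let mut idx: Vec<usize> = (0..row.len()).collect();
    // With a descending comparator, after this call idx[..k] holds the k largest
    // values' indices (in no particular order); expected O(n).
    idx.select_nth_unstable_by(k - 1, |&a, &b| row[b].total_cmp(&row[a]));
    idx.truncate(k);
    // Sort only the k winners so the order matches torch.topk's sorted output.
    idx.sort_by(|&a, &b| row[b].total_cmp(&row[a]));
    idx
}
```

For example, on `[0.3, 0.9, 0.1, 0.7, 0.5]` with `k = 2` this yields `[1, 3]`. On a GPU the trade-off differs, since a parallel full sort can be competitive.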

Definitely would be a useful feature to add! Happy to help implement it.

gregszumel, Apr 25 '24

@gregszumel, thanks for the explanation. I would love to see this added to Candle. If you want to contribute this to mistral.rs please feel free!

EricLBuehler, Apr 26 '24

It's certainly not an efficient top_k, as we do fully sort the array, but the bigger problem is that the operation takes place on the CPU, so it triggers a synchronization point (the routing weights have to be copied back from the GPU first). I'm putting together #2132, which should let this run fully on the GPU. It will still sort the whole array, but on GPUs that isn't necessarily much slower than a topk, since the sort is done in parallel. (Also, the number of experts is very small, 8, so this wouldn't be much of an issue even on the CPU.)

LaurentMazare, Apr 27 '24

@LaurentMazare, thanks! I saw that PR and am very excited for it to be merged.

EricLBuehler, Apr 27 '24

@EricLBuehler Looks like it got merged, thanks @LaurentMazare!

I hope to find some time this week to implement top-k. I think its home is probably either mistral.rs or candle-ext, but the new argsort is probably a great workaround for now.

gregszumel, Apr 28 '24

@gregszumel, that sounds great! If you decide to contribute it to mistral.rs, that would be much appreciated.

EricLBuehler, Apr 28 '24

For future reference, here's the implementation: https://github.com/EricLBuehler/mistral.rs/blob/6aec940499be1cf72c628f7ddaa8b3e59bcb4fda/mistralrs-core/src/ops.rs#L482-L504

EricLBuehler, Jul 23 '24