[Question] Use tutel with attention-based experts
Hi, I am currently using tutel with custom FF-based experts, with an implementation similar to main/tutel/examples/helloworld_from_scratch.py#L59. While this works nicely, I also want to use tutel with attention-based experts.
In this setting, moe.top_k_routing and moe.fast_encode expect a tensor of shape (T, C), or, when using a batch size > 1, a flattened tensor of shape (B * T, C), where T is the sequence length. This is fine for experts that act independently on each token, such as linear layers, but it doesn't work for attention layers: without e.g. an appropriate attention mask, there would be cross-batch attention, which I definitely do not want.
Is there a way to make top_k_routing and fast_encode work on tensors of shape (B, T, C)? I already tried running them B times (once per batch element) and concatenating the results with padding afterwards, but that is extremely slow in backpropagation. I also tried to derive an attention mask that tells me which tokens in the resulting (<(B * T), C) output originally belonged to which batch, using the output of top_k_routing and following the logic in fast_encode, but I just don't fully understand the operations done there.
Is there a simple way that I am overlooking, or has someone else already implemented something similar?
Hello. I am trying to understand your requirements. Are you looking for extra information from the result of https://github.com/microsoft/Tutel/blob/036614c68d058957bbf02ec4392b945993957902/tutel/examples/helloworld_from_scratch.py#L59C23-L59C41 that tells you which token belongs to which sequence?
In addition, is this information needed only "before line 61 or after line 63"? If so, no extra communication is needed. Otherwise, i.e. if you want to know the token belongings during the expert computation https://github.com/microsoft/Tutel/blob/036614c68d058957bbf02ec4392b945993957902/tutel/examples/helloworld_from_scratch.py#L62, then extra communication has to be involved.
As I am using only single-GPU training, I guess no extra communication is needed at all (can I even omit the all_to_all in my code?).
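A minimal sketch of such a guard, assuming the all_to_all is an identity when the world size is 1 (`all_to_all_fn` below is a placeholder for the collective used in the example):

```python
import torch.distributed as dist

def maybe_all_to_all(y, all_to_all_fn):
    # With a single process/GPU, the all-to-all would only exchange data with
    # itself, so it can be skipped (assumption: the standard expert-parallel
    # layout from helloworld_from_scratch.py).
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        return all_to_all_fn(y)
    return y
```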
In the helloworld_from_scratch example the batch size is 1, but I have seen in https://github.com/microsoft/Tutel/blob/036614c68d058957bbf02ec4392b945993957902/tutel/impls/moe_layer.py#L264 that for multi-batch scenarios the batch and token dimensions are flattened before being handled by the subsequent operations, so I do the same in my "from scratch" implementation.
Assume that my input tensor has shape (2, 2048, 320), i.e. (B, T, C). To make top_k_routing work, I have to flatten this tensor to (B * T, C), i.e. (4096, 320); otherwise an error is thrown. The resulting crit output together with fast_encode then gives me a tensor y of e.g. shape (1037, 320), in general of shape (<(B * T), C). This is nice for my MoE, as it reduces the number of required computations, and it is perfectly fine for my feed-forward MoE, where each token is treated separately.
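For concreteness, a minimal sketch of the flow I described (the exact argument names and return values are assumed here from this thread; consult helloworld_from_scratch.py for the real signatures, and the gate scores are just a stand-in):

```python
import torch
from tutel import moe  # import as in helloworld_from_scratch.py

B, T, C, E = 2, 2048, 320, 8                 # E = number of experts (made up here)
x = torch.randn(B, T, C, device='cuda')
x_flat = x.view(B * T, C)                    # (4096, 320): both sequences pooled

scores = torch.randn(B * T, E, device='cuda').softmax(dim=-1)  # stand-in gate output

# NOTE: signature assumed from this thread; top_k_routing may also return an
# auxiliary loss. See the example file for the exact arguments.
crit = moe.top_k_routing(scores, top_k=1)
y = moe.fast_encode(x_flat, crit)            # e.g. (1037, 320), i.e. (< B*T, C)
```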
But I cannot use this output tensor y for an expert with attention, as e.g. a self-attention would now attend any of the 1037 tokens to each other, effectively computing attention across tokens from multiple batches.
As I see it there are two possible ways to go from here:
- Make `top_k_routing` and `fast_encode` "batch-aware", i.e. supporting inputs of shape (2, 2048, 320) and outputting a tensor `y` with e.g. shape (2, 1185, 320) that I can use in my expert.
- Somehow construct, or better, get returned from one of the two functions, a binary attention mask of shape (1037, 1037) that tells me exactly which tokens belonged to the same batch and therefore should attend to each other. Alternatively, a binary tensor of shape (1037, 2) that tells me to which batch a token originally belonged before `fast_encode`; then I can construct the attention mask myself (a sketch of building such a mask follows this list).
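For illustration, given such a per-token batch-id tensor, building the mask itself is trivial in plain PyTorch (nothing Tutel-specific; `batch_ids` is the hypothetical tensor described above):

```python
import torch

def batch_id_attention_mask(batch_ids: torch.Tensor) -> torch.Tensor:
    """batch_ids: int tensor of shape (N,), batch_ids[i] = original batch index.

    Returns a boolean (N, N) mask that is True exactly where two tokens came
    from the same batch element, i.e. a block-diagonal self-attention mask
    (up to the permutation introduced by the dispatch)."""
    return batch_ids.unsqueeze(0) == batch_ids.unsqueeze(1)

# Example: tokens from batches [0, 0, 1] -> the first two may attend to each
# other, the third only to itself.
print(batch_id_attention_mask(torch.tensor([0, 0, 1])))
```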
For point-1, can I assume that your requirement is to have an interface like batched_fast_encode, which runs multiple fast_encode/decode calls independently?
For point-2 in your last comment, does a mask of shape (1037, 1037) still work for batch > 2? And for the binary tensor of shape (1037, 2): why not an int32 tensor of shape [1037,] where tensor[i] = batchIdx, e.g. either 0 or 1 if bsz == 2? (BTW, this option seems to require the fewest changes.)
Exactly, a batched_fast_encode (plus batch support for top_k_routing, as it currently throws an error if ndim > 2) is exactly what I have in mind. I already implemented a naive method where I run a for loop over the batch, use fast_encode on each batch element, and then apply F.pad(...) to each result to concatenate them again, roughly as sketched below. In theory this gives me a tensor of shape (batch_size, < T, C), but it is very slow and essentially unusable.
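The naive version looks roughly like this (a sketch only; `route_and_encode` is a placeholder for my per-sequence top_k_routing + fast_encode call):

```python
import torch
import torch.nn.functional as F

def batched_encode_naive(x, route_and_encode):
    # x: (B, T, C); route_and_encode: (T, C) -> (L_b, C) with L_b <= T.
    outs = [route_and_encode(x[b]) for b in range(x.shape[0])]
    max_len = max(o.shape[0] for o in outs)
    # Pad each per-sequence result along the token dimension, then stack
    # everything back into a single (B, max_len, C) tensor.
    return torch.stack([F.pad(o, (0, 0, 0, max_len - o.shape[0])) for o in outs])
```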
You are right, (1037, 1037) would not work for batch_size > 2, so forget about that. An int32 tensor of shape (1037,) giving the original batch id would be totally fine. I guess this information should be extractable from the output of top_k_routing, even though the tensor has been flattened before that, but I just cannot make sense of the different lists and tensors that top_k_routing returns.
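Before fast_encode the mapping is trivial, since flattened token i comes from batch i // T; the open question is only how to carry this through the dispatch permutation. A tiny illustration (not Tutel-specific):

```python
import torch

B, T = 2, 2048
flat_idx = torch.arange(B * T)
batch_ids = (flat_idx // T).to(torch.int32)  # batch id of each flattened token
assert batch_ids[0] == 0 and batch_ids[T - 1] == 0 and batch_ids[T] == 1
```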
This PR extends 2 helper functions that should handle your requirement. It will take about 1 day to get merged once the review is approved.
This line may be exactly what you need: https://github.com/microsoft/Tutel/blob/main/tutel/examples/helloworld_from_scratch.py#L72
Thanks a lot for your help and time! That's exactly what I need.
@ghostplant Hey again. The functions you provided work like a charm, so thanks again. However, there is one case where get_reversed_sample_ids throws the following error (technically, it is not an error that directly stops the program, but the resulting tensor is inaccessible):
/pytorch/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [1,0,0], thread: [62,0,0] Assertion "idx_dim >= 0 && idx_dim < index_size && "index out of bounds"" failed.
This happens whenever 0.0 < capacity_factor < 1.0. In that case, the dimension of the result in get_reversed_sample_ids is smaller than the largest offset, so an index-out-of-bounds error is thrown in scatter.
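For illustration, the assertion itself is the standard scatter bounds check; a minimal stand-alone repro of the same failure mode (on CPU it surfaces as a clean RuntimeError instead of a device-side assert):

```python
import torch

dst = torch.zeros(4)
idx = torch.tensor([0, 1, 5])   # 5 is out of bounds for a size-4 destination
src = torch.ones(3)
dst.scatter_(0, idx, src)       # RuntimeError: index 5 is out of bounds ...
# On CUDA the same condition trips the ScatterGatherKernel.cu assertion shown
# above, after which the output tensor can no longer be read back.
```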
I know that choosing a capacity factor smaller than 1 is quite a niche use case, but maybe you see a quick way to fix this.
Yes, any non-dropless capacity may result in location overflow. The issue should be resolved by this fix: https://github.com/microsoft/Tutel/pull/289
Oh yes, that works now. Thank you very much!