Implement bi-directionality
Edit:
- [x] Implement bi-directionality by applying the `Mamba` module twice: (1) to the forward sequence and (2) to the backward sequence.
- [x] Implement ~~3~~ 2 strategies for combining forward / backward Mamba hidden states:
  - `add`: Add the states.
  - ~~`concat`: Concatenate the states. This doubles the hidden dimension, `d_model`, which also prevents weight tying between the `embedding` and `lm_head` weights.~~
  - `ew_multiply`: Perform element-wise multiplication between the states.
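A rough sketch of how this can be wired up is below. This is not the code from this repo; the `BiMambaBlock` name, the `(batch, length, d_model)` layout, and the default `Mamba` hyperparameters are assumptions for illustration only.

```python
import torch
import torch.nn as nn
from mamba_ssm import Mamba

class BiMambaBlock(nn.Module):
    """Illustrative bidirectional wrapper: one Mamba scan over the sequence
    and one over its reverse, combined with add or ew_multiply."""

    def __init__(self, d_model, combine="add"):
        super().__init__()
        assert combine in ("add", "ew_multiply")
        self.combine = combine
        self.fwd = Mamba(d_model=d_model)
        self.bwd = Mamba(d_model=d_model)

    def forward(self, x):  # x: (batch, length, d_model)
        h_fwd = self.fwd(x)
        # Reverse along the sequence dimension, scan, then reverse back
        # so both directions are aligned per token.
        h_bwd = self.bwd(x.flip(dims=[1])).flip(dims=[1])
        if self.combine == "add":
            return h_fwd + h_bwd
        return h_fwd * h_bwd  # ew_multiply

# Usage (mamba_ssm kernels require a CUDA device):
# block = BiMambaBlock(d_model=16, combine="add").to("cuda")
# y = block(torch.randn(2, 64, 16, device="cuda"))
```

Whether the two directions share weights (one `Mamba` reused for both scans) or use two separate modules, as sketched here, is a design choice the checklist above leaves open.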
What if the sequences have padding? E.g. the input is [1 2 3 0 0 0], so the flipped input would be [0 0 0 3 2 1]. Shouldn't it be [3 2 1 0 0 0]?
@sentialx, agreed. That's a good catch.
How does the speed compare to uni-directional?
> How does the speed compare to uni-directional?
@jimmieliu, it's about 2x
@yair-schiff I am just curious, did you solve the issue below?

> What if the sequences have padding? E.g. the input is [1 2 3 0 0 0], so the flipped input would be [0 0 0 3 2 1]. Shouldn't it be [3 2 1 0 0 0]?
Just curious, is this problem solved?
I came up with a solution to the padding issue. Say we have the tensor [1,2,3,0,0], where 0 is the padding token. We flip it to get [0,0,1,2,3], pass it through the network, and flip the output back. Because we apply two flips, the information in the flipped tensor ends up aligned with the original token order.
```
given: x
out = x + f(x.flip()).flip()
```
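A runnable toy version of that pattern is below; the GRU is just a stand-in for a Mamba block, and note that `flip` needs an explicit sequence dimension in PyTorch.

```python
import torch
import torch.nn as nn

# Stand-in sequence mixer; any (batch, length, dim) -> (batch, length, dim) module works here.
gru = nn.GRU(input_size=16, hidden_size=16, batch_first=True)

def f(x):
    out, _ = gru(x)
    return out

x = torch.randn(2, 64, 16)                    # (batch, length, dim)
out = x + f(x.flip(dims=[1])).flip(dims=[1])  # reverse the sequence, run f, reverse back
print(out.shape)                              # torch.Size([2, 64, 16])
```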
> I came up with a solution to the padding issue. Say we have the tensor [1,2,3,0,0], where 0 is the padding token. We flip it to get [0,0,1,2,3], pass it through the network, and flip the output back. Because we apply two flips, the information in the flipped tensor ends up aligned with the original token order.
>
> given: x
> out = x + f(x.flip()).flip()
Hi, your approach is clever! But I have a question: if you flip the input to [0,0,1,2,3], does the padding in front of it affect how the sequence's hidden features are learned? I.e., does it produce a different (worse) representation of the sequence than an input of [3,2,1,0,0] would? I don't know enough about this; could you possibly give me some guidance? It would help me a lot. Thank you very much!
@xuanwuji well, you can remove the leading padding by shifting each row of x before flipping it. As for its effect: since the hidden state is initialized to 0, it should still be all zeros after scanning through the padding, so the padding shouldn't have any effect on the result. However, you can use the following function just to be sure.
```python
import torch

def flip_padded_hidden_states(hidden_states, seq_lens):
    """Reverse the valid tokens of each right-padded sequence, e.g. [1,2,3,0,0] -> [3,2,1,0,0]."""
    batch_size, seq_len, hidden_dim = hidden_states.shape
    # Flat index of every position in the (batch, seq_len) grid.
    indices = torch.arange(batch_size * seq_len, device=hidden_states.device).reshape(
        batch_size, seq_len
    )
    # Circularly shift each row right by its amount of padding, then flip it,
    # so the valid tokens end up reversed and left-aligned.
    indices_offset = seq_len - seq_lens
    indices = (indices - indices_offset.unsqueeze(1)) % (seq_len * batch_size)
    indices = indices.flip(1)
    return hidden_states.reshape(batch_size * seq_len, hidden_dim)[indices]
```
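For example, running it on the [1,2,3,0,0] tensor from this thread (treating the token ids as 1-dim hidden states just so the result is easy to read, and assuming the function above is in scope):

```python
import torch

hidden_states = torch.tensor([[1., 2., 3., 0., 0.]]).unsqueeze(-1)  # (batch=1, seq_len=5, hidden_dim=1)
seq_lens = torch.tensor([3])  # number of valid (non-padding) tokens per sequence

flipped = flip_padded_hidden_states(hidden_states, seq_lens)
print(flipped.squeeze(-1))  # tensor([[3., 2., 1., 0., 0.]])
```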
To check the effect of the padding:
```python
import torch
from mamba_ssm import Mamba
from torch.nn import functional as F

batch, length, dim = 2, 64, 16
model = Mamba(
    d_model=dim,  # Model dimension d_model
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
).to("cuda")

x = torch.randn(batch, length, dim).to("cuda")
padded_x = F.pad(x, (0, 0, 4, 0))  # prepend 4 zero (padding) positions along the length dim

y = model(x)
padded_y = model(padded_x)
unpadded_y = padded_y[:, 4:]  # drop the 4 padded positions from the output

print(f'Output max diff: {(unpadded_y - y).abs().max().item()}')
print(f'Output mean diff: {(unpadded_y - y).abs().mean().item()}')
```
However, these errors do accumulate across multiple layers, so you should use the flip_padded_hidden_states function just to be certain.
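For instance, a padded bidirectional forward pass could apply the helper on both sides of the backward scan. This is purely illustrative; `fwd_mamba`, `bwd_mamba`, and `seq_lens` are assumed inputs, not names from this repo, and the `flip_padded_hidden_states` helper above is assumed to be in scope.

```python
def bi_forward(fwd_mamba, bwd_mamba, x, seq_lens):
    """x: right-padded (batch, length, d_model); seq_lens: valid tokens per sequence."""
    h_fwd = fwd_mamba(x)
    x_rev = flip_padded_hidden_states(x, seq_lens)                 # reverse valid tokens, keep padding at the end
    h_bwd = flip_padded_hidden_states(bwd_mamba(x_rev), seq_lens)  # undo the reversal to re-align per token
    return h_fwd + h_bwd  # "add" combination from the checklist above
```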