
Implement bi-directionality

Open yair-schiff opened this issue 1 year ago • 8 comments

Edit:

  • [x] Implement bi-directionality by applying the Mamba module twice: (1) to the forward sequence and (2) to the backward sequence.
  • [x] Implement ~3~ 2 strategies for combining forward / backward Mamba hidden states (see the sketch after this list):
    1. add: Add the states.
    2. ~concat: Concatenate the states. This doubles the hidden dimension, d_model, which also prevents weight tying between the embedding and lm_head weights.~
    3. ew_multiply: Perform element-wise multiplication between the states.
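
A minimal sketch of what such a bidirectional wrapper could look like (BiMamba and the combine argument are illustrative names, not code from this repo; Mamba is the block from mamba_ssm):

import torch.nn as nn
from mamba_ssm import Mamba

class BiMamba(nn.Module):
    def __init__(self, d_model, combine="add"):
        super().__init__()
        assert combine in ("add", "ew_multiply")
        self.combine = combine
        self.fwd = Mamba(d_model=d_model)  # processes the forward sequence
        self.bwd = Mamba(d_model=d_model)  # processes the reversed sequence

    def forward(self, x):  # x: (batch, length, d_model)
        out_fwd = self.fwd(x)
        # Run the second Mamba on the reversed sequence, then flip back
        out_bwd = self.bwd(x.flip(1)).flip(1)
        if self.combine == "add":
            return out_fwd + out_bwd
        return out_fwd * out_bwd  # ew_multiply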

yair-schiff · Dec 13 '23

What if the sequences have padding? E.g., the input is [1 2 3 0 0 0], so the flipped input would be [0 0 0 3 2 1]. Shouldn't it be [3 2 1 0 0 0]?
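
A quick illustration of the difference between the naive flip and a length-aware flip (a toy example, where 0 is assumed to be the padding id):

import torch

x = torch.tensor([1, 2, 3, 0, 0, 0])
valid_len = 3  # number of non-padding tokens

naive = x.flip(0)                          # tensor([0, 0, 0, 3, 2, 1])
aware = x.clone()
aware[:valid_len] = x[:valid_len].flip(0)  # tensor([3, 2, 1, 0, 0, 0])
print(naive.tolist(), aware.tolist())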

sentialx · Dec 24 '23

@sentialx , agreed. That's a good catch.

yair-schiff · Dec 24 '23

How does the speed compare to uni-directional?

jimmieliu · Jan 02 '24

How does the speed compare to uni-directional?

@jimmieliu, it's about 2x the cost of uni-directional.
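
A rough way to check this yourself (a sketch only; the weight-tied bidirectional pass below just reuses one block for a quick cost estimate, and the numbers depend on hardware and shapes):

import time
import torch
from mamba_ssm import Mamba

batch, length, dim = 8, 2048, 256
model = Mamba(d_model=dim).to("cuda")
x = torch.randn(batch, length, dim, device="cuda")

@torch.no_grad()
def bench(fn, iters=50):
    for _ in range(5):  # warm-up
        fn()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.time() - start) / iters

uni = bench(lambda: model(x))
bi = bench(lambda: model(x) + model(x.flip(1)).flip(1))
print(f"uni: {uni * 1e3:.2f} ms, bi: {bi * 1e3:.2f} ms")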

yair-schiff · Jan 03 '24

@yair-schiff I am just curious, did you solve the issue below?

What if the sequences have padding? E.g., the input is [1 2 3 0 0 0], so the flipped input would be [0 0 0 3 2 1]. Shouldn't it be [3 2 1 0 0 0]?

Is this problem solved?

pengzhangzhi · Jan 24 '24

I came up with a solution to the padding issue. Say we have a tensor [1,2,3,0,0], where 0 is the padding token. We flip it to get [0,0,3,2,1], pass it to the network, and flip the output back. Because we apply two flips, the result lines up with the original token order.

given: x
out = x + f(x.flip()).flip()
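
A toy check of why the double flip keeps positions aligned (here f is just the identity, standing in for the backward Mamba pass):

import torch

x = torch.tensor([[1, 2, 3, 0, 0]])
f = lambda t: t                 # stand-in for the backward pass
out = x + f(x.flip(1)).flip(1)  # tensor([[2, 4, 6, 0, 0]])
print(out)
# Position i of f(x.flip(1)).flip(1) corresponds to position i of x, so the
# residual addition stays aligned even though f processed [0, 0, 3, 2, 1].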

pengzhangzhi · Jan 25 '24

I came up with a solution to the padding issue. Say we have a tensor [1,2,3,0,0], where 0 is the padding token. We flip it to get [0,0,3,2,1], pass it to the network, and flip the output back. Because we apply two flips, the result lines up with the original token order.

given: x
out = x + f(x.flip()).flip()

Hi, your approach is clever! But I have a question: if you flip the input to [0,0,3,2,1], does the padding in front affect the learned sequence representation? I.e., does it produce a different (worse) representation than the input [3,2,1,0,0]? I don't know enough about this; could you give me some guidance? It would help me a lot. Thank you very much!

xuanwuji · Jul 13 '24

@xuanwuji well, you can remove the leading padding by shifting each row of x before flipping x. As for its effect: since the hidden state is initialized to 0, it should still be 0 after scanning through the padding, so the padding shouldn't have any effect on the result. However, you can use the following function just to be sure.

import torch

def flip_padded_hidden_states(hidden_states, seq_lens):
    # Reverse only the valid prefix of each right-padded sequence, so the
    # reversed tokens sit at the front; the trailing (padded) positions end up
    # holding don't-care values that downstream code should ignore.
    # hidden_states: (batch_size, seq_len, hidden_dim), seq_lens: (batch_size,)
    batch_size, seq_len, hidden_dim = hidden_states.shape

    # Flat position index of every token in the (batch * seq_len) view
    indices = torch.arange(batch_size * seq_len, device=hidden_states.device).reshape(
        batch_size, seq_len
    )

    # Shift each row left by its amount of padding, then flip along the
    # sequence dimension
    indices_offset = seq_len - seq_lens
    indices = (indices - indices_offset.unsqueeze(1)) % (seq_len * batch_size)
    indices = indices.flip(1)

    return hidden_states.reshape(batch_size * seq_len, hidden_dim)[indices]
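
For instance, applying it to a toy batch (hypothetical shapes, just to see what the function returns):

import torch

hidden_states = torch.arange(10, dtype=torch.float32).reshape(1, 5, 2)  # 1 sequence, 5 steps, dim 2
seq_lens = torch.tensor([3])                                            # only the first 3 steps are valid
flipped = flip_padded_hidden_states(hidden_states, seq_lens)
print(flipped[:, :3])  # the 3 valid steps, in reversed order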

To check the effect of padding:

import torch
from mamba_ssm import Mamba
from torch.nn import functional as F

batch, length, dim = 2, 64, 16

model = Mamba(
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")

x = torch.randn(batch, length, dim).to("cuda")
padded_x = F.pad(x, (0, 0, 4, 0))  # prepend 4 zero frames along the length dimension

y = model(x)
padded_y = model(padded_x)

unpadded_y = padded_y[:, 4:]  # drop the outputs at the padded positions

print(f'Output max diff: {(unpadded_y - y).abs().max().item()}')
print(f'Output mean diff: {(unpadded_y - y).abs().mean().item()}')

However, these errors do accumulate over multiple layers, so you should use the flip_padded_hidden_states function just to be certain.
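
As a rough illustration of that accumulation, one can stack a few independent Mamba blocks and repeat the comparison above (an illustrative sketch; the layer count and shapes are arbitrary and the numbers will vary):

import torch
from mamba_ssm import Mamba
from torch.nn import functional as F

batch, length, dim, n_layers = 2, 64, 16, 8
layers = torch.nn.ModuleList([Mamba(d_model=dim) for _ in range(n_layers)]).to("cuda")

x = torch.randn(batch, length, dim, device="cuda")
padded = F.pad(x, (0, 0, 4, 0))  # 4 leading zero frames on the length axis

with torch.no_grad():
    for i, layer in enumerate(layers):
        x = layer(x)
        padded = layer(padded)
        diff = (padded[:, 4:] - x).abs().max().item()
        print(f"layer {i}: max diff {diff:.3e}")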

Museum7432 · Jul 14 '24