Encoder-decoder architecture

Open · sentialx opened this issue 1 year ago • 6 comments

What would be the preferred way to make an encoder-decoder architecture with Mamba? I tried concatenating embeddings to decoder inputs with no luck. My use case is a diffusion model and the encoder would be used for conditioning
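
For concreteness, a rough sketch of this kind of conditioning (the encoder output is simply prepended along the sequence dimension here; Mamba is the block from mamba_ssm, and the wiring is only illustrative):

import torch
import torch.nn as nn
from mamba_ssm import Mamba

class PrefixConditionedDecoder(nn.Module):
    # Illustrative only: condition a Mamba decoder by prefixing encoder embeddings.
    def __init__(self, d_model, n_layers=4):
        super().__init__()
        self.layers = nn.ModuleList([Mamba(d_model=d_model) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, cond):
        # x: (B, L_x, D) noisy diffusion inputs; cond: (B, L_c, D) encoder embeddings
        h = torch.cat([cond, x], dim=1)      # the causal scan sees the conditioning first
        for layer in self.layers:
            h = h + layer(self.norm(h))      # simple pre-norm residual wiring
        return h[:, cond.shape[1]:]          # keep only the decoder positions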

sentialx · Dec 24 '23 11:12

It's still an open question how to do this with SSMs.

albertfgu · Dec 24 '23 19:12

Due to the inherently causal nature of the model, I struggle to see how one might build an encoder-only component, which fits with it still being an open question, as @albertfgu stated. Hopefully this is solvable, because until then I struggle to see how this model will translate to non-autoregressive (non-continuative) or multi-modal problem domains.

ElliottDyson · Feb 02 '24 22:02

@ElliottDyson Do you still have the code from your attempt? What went wrong? Was the model just not converging?

@albertfgu To your knowledge, has any work been done on vector to vector Mamba or Mamba derivative models?

stanleyshly · Mar 31 '24 21:03

As for the code, I'm afraid not; I never got that far due to other projects. As for Mamba derivatives, have a flick through the most recent pages of papers on Hugging Face; there have been a few.

ElliottDyson · Mar 31 '24 22:03

Thank you for your response. I don't see any Mamba models that do vec2vec, though. Do you have a link to any?

stanleyshly · Mar 31 '24 23:03

Something along these lines may work (sequence length of 1, fixed regression task). Replace the Mamba class's forward method with:

def forward(self, hidden_states):
    batch, seqlen, dim = hidden_states.shape
    assert seqlen == 1, "For regression, the input should be a single vector"
    
    # We do matmul and transpose BLH -> HBL at the same time
    xz = rearrange(
        self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
        "d (b l) -> b d l",
        l=seqlen,
    )
    if self.in_proj.bias is not None:
        xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
    
    x, z = xz.chunk(2, dim=1)
    
    # Compute short convolution
    x = self.act(self.conv1d(x)[..., :seqlen])
    
    x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
    dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
    dt = self.dt_proj.weight @ dt.t()
    dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
    B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
    C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()

    A = -torch.exp(self.A_log.float())  # (d_inner, d_state); needed by selective_scan_fn below
    assert self.activation in ["silu", "swish"]
    y = selective_scan_fn(
        x,
        dt,
        A,
        B,
        C,
        self.D.float(),
        z=z,
        delta_bias=self.dt_proj.bias.float(),
        delta_softplus=True,
        return_last_state=False,
    )
    
    y = rearrange(y, "b d l -> b l d")
    out = self.out_proj(y)
    return out.squeeze(1)  # Remove the sequence dimension
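
For this single-vector variant, a call would then look something like the following (just a sketch; mamba_block stands for a Mamba module patched with the forward above, and the sizes are placeholders):

import torch

batch_size, d_model = 8, 256               # hypothetical sizes
x = torch.randn(batch_size, 1, d_model)    # each example is one vector, wrapped as a length-1 sequence
y = mamba_block(x)                         # -> (batch_size, d_model) after the final squeeze(1)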

Or, if you meant sequence of vectors to sequence of vectors (the original implementation, but with continuous inputs instead of tokens), then try this instead for the Mamba forward method:

def forward(self, hidden_states):
    batch, seqlen, dim = hidden_states.shape
    
    # We do matmul and transpose BLH -> HBL at the same time
    xz = rearrange(
        self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
        "d (b l) -> b d l",
        l=seqlen,
    )
    if self.in_proj.bias is not None:
        xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
    
    x, z = xz.chunk(2, dim=1)
    
    # Compute short convolution
    x = self.act(self.conv1d(x)[..., :seqlen])
    
    x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
    dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
    dt = self.dt_proj.weight @ dt.t()
    dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
    B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
    C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()

    A = -torch.exp(self.A_log.float())  # (d_inner, d_state); needed by selective_scan_fn below
    assert self.activation in ["silu", "swish"]
    y = selective_scan_fn(
        x,
        dt,
        A,
        B,
        C,
        self.D.float(),
        z=z,
        delta_bias=self.dt_proj.bias.float(),
        delta_softplus=True,
        return_last_state=False,
    )
    
    y = rearrange(y, "b d l -> b l d")
    out = self.out_proj(y)
    return out

Please let me know if this ends up working 🙂

Of course, you'll also need to change the input and output preprocessing to not use tokenisation, and change the loss function for training too.
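
In outline, that change might look something like this (a sketch only; MambaBackbone stands in for whatever stack of Mamba blocks you use, and all names here are placeholders):

import torch.nn as nn

class ContinuousMambaModel(nn.Module):
    # Sketch: linear projections in place of the embedding table and LM head.
    def __init__(self, d_in, d_model, d_out):
        super().__init__()
        self.in_proj = nn.Linear(d_in, d_model)    # replaces nn.Embedding
        self.backbone = MambaBackbone(d_model)     # placeholder for your Mamba stack
        self.out_proj = nn.Linear(d_model, d_out)  # replaces the LM head

    def forward(self, x):                          # x: (B, L, d_in), continuous features
        return self.out_proj(self.backbone(self.in_proj(x)))

# Train with a regression loss instead of cross-entropy, e.g.
# loss = nn.functional.mse_loss(model(features), targets)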

ElliottDyson · Apr 01 '24 10:04