Strict requirement of a **diagonal** `A`

Open buttercutter opened this issue 1 year ago • 4 comments
I have another question on the strict requirement of a diagonal `A`, arising from some mathematical relationship with Glorot/Xavier initialization for `A`.

[Four attached images, not reproduced here]

References: [1] Resurrecting Recurrent Neural Networks for Long Sequences. [2] HiPPO: Recurrent Memory with Optimal Polynomial Projections.

buttercutter avatar Dec 20 '23 16:12 buttercutter

I'm sorry, what's your question?

albertfgu avatar Dec 20 '23 17:12 albertfgu

The question is: in the papers, A needs to be diagonal, but your code in https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py has:

    # S4D real initialization
    A = repeat(
        torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
        "n -> d n",
        d=self.d_inner,
    ).contiguous()
    A_log = torch.log(A)  # Keep A_log in fp32
    A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

Which is not diagonal.

Thanks!

Wei

houghtonweihu avatar Jan 11 '24 19:01 houghtonweihu

A is technically a batch of d_inner diagonal matrices, each of size d_state x d_state. Since it's diagonal, we don't need to store all the d_state x d_state entries, we just need to store d_state entries. So here we're storing (d_inner, d_state) entries.
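A minimal sketch of the storage argument above, using plain Python lists rather than PyTorch tensors so that it runs without dependencies. The toy sizes and the helper names (`expand_to_dense`, `dense_matvec`, `diagonal_matvec`) are assumptions made for illustration and do not appear in the Mamba code. It checks that multiplying by the full `d_state x d_state` diagonal matrix gives the same result as an elementwise product with the stored diagonal:

```python
# Hypothetical toy sizes; real models use much larger d_inner and d_state.
d_inner, d_state = 2, 3

# Compact storage with shape (d_inner, d_state): row d holds the diagonal
# entries -1, -2, ..., -d_state of the d-th matrix (S4D real initialization).
A_compact = [[-float(n) for n in range(1, d_state + 1)] for _ in range(d_inner)]


def expand_to_dense(diagonals):
    """Materialize each stored diagonal as a full square matrix."""
    dense = []
    for diag in diagonals:
        size = len(diag)
        dense.append(
            [[diag[i] if i == j else 0.0 for j in range(size)] for i in range(size)]
        )
    return dense


def dense_matvec(matrix, vector):
    """Ordinary matrix-vector product."""
    return [sum(a * x for a, x in zip(row, vector)) for row in matrix]


def diagonal_matvec(diag, vector):
    """Matrix-vector product for a diagonal matrix stored as its diagonal."""
    return [a * x for a, x in zip(diag, vector)]


A_dense = expand_to_dense(A_compact)  # shape (d_inner, d_state, d_state)
h = [0.5, -1.0, 2.0]  # an arbitrary state vector of length d_state

# Both representations act identically on h, so the off-diagonal zeros
# never need to be stored.
for d in range(d_inner):
    assert dense_matvec(A_dense[d], h) == diagonal_matvec(A_compact[d], h)
```

The dense form holds `d_inner * d_state * d_state` numbers, while the compact form holds only `d_inner * d_state`, which is why a 2-D tensor is enough to represent the whole batch of diagonal matrices.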

tridao avatar Jan 11 '24 19:01 tridao

@tridao Thank you for the clear explanation. You may want to add this as a comment in the file so others can benefit from it as well. Great work!

houghtonweihu avatar Jan 11 '24 20:01 houghtonweihu