exp the log of A?
What's the purpose of taking the log of A and then exponentiating?
https://github.com/state-spaces/mamba/blob/6dbfc4553a98c81e1e93b8fd2d5abf387f5c09ee/mamba_ssm/modules/mamba_simple.py#L146
So that we can use self.A_log as the parameter, with no restriction on its range. If we parameterized A directly as a parameter, it would be harder to constrain A to be positive (which is what we want).
For context, this parameterization has been adopted from earlier work on structured SSMs: see S4 and related papers.
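For concreteness, here is a minimal sketch of the pattern in the linked code (assumes PyTorch; the class name is hypothetical and the shapes are simplified, but the `A_log` parameter and its init mirror mamba_simple.py):

```python
import torch
import torch.nn as nn


class ALogParam(nn.Module):
    """Hypothetical wrapper illustrating the A_log trick from mamba_simple.py."""

    def __init__(self, d_inner: int, d_state: int):
        super().__init__()
        # A starts positive (1, 2, ..., d_state), repeated for each channel.
        A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner, 1)
        # Store log(A): the optimizer can move this anywhere in (-inf, inf).
        self.A_log = nn.Parameter(torch.log(A))

    def forward(self) -> torch.Tensor:
        # exp(.) maps any real value to a positive one, so -exp(A_log) is
        # always negative (a decaying, stable recurrence) without any
        # clamping or projection step on the parameter itself.
        return -torch.exp(self.A_log)


m = ALogParam(d_inner=4, d_state=16)
A = m()
print(A.shape, bool((A < 0).all()))  # torch.Size([4, 16]) True
```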
Building on Tri's point about the parameter range, the advantages of this reparameterization go beyond simply adjusting the numerical range. Projecting the parameters onto a positive (or negative) range would be a workable alternative, but the reparameterization offers additional benefits.
[Theory explanation] Let W be the trainable weight and A be the realized weight.
- With A = W, recurrent models have difficulty learning long-term memory. (This difficulty is attributed to several phenomena, including catastrophic forgetting, exponential decay, and what is often called the curse of memory.)
- With A = -exp(W), state-space models can learn (or, in the paper's terms, stably approximate) targets with long-term memory.
So this reparameterization of A is important for learning long-term memory.
Reference can be found in: https://arxiv.org/abs/2311.14495
Besides the approximation argument, the paper also notes that the gradient scale under this reparameterization is usually milder across different decay patterns. (Other reparameterizations such as A = -softplus(W) have similar benefits.)
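As a rough illustration of the milder gradient scale (a toy sketch of my own, not taken from the paper): for the same realized slow decay A ≈ -0.01 over a long horizon, the gradient with respect to the raw weight W is much smaller under the exp/softplus reparameterizations than under the direct one, because the chain rule contributes an extra factor of roughly |A|.

```python
import math
import torch
import torch.nn.functional as F

t = 50.0  # long horizon: long-term memory = signal surviving many steps

def grad_wrt_w(realize, w_value):
    # Gradient of the toy "memory" exp(A * t) with respect to the raw weight W.
    w = torch.tensor(w_value, requires_grad=True)
    A = realize(w)
    torch.exp(A * t).backward()
    return w.grad.item()

# All three choices realize (roughly) the same slow decay A ~= -0.01,
# but the chain rule scales the gradient back to W very differently.
print(grad_wrt_w(lambda w: w, -0.01))                        # A = W:           ~30   (t * exp(A t))
print(grad_wrt_w(lambda w: -torch.exp(w), math.log(0.01)))   # A = -exp(W):     ~-0.3 (extra |A| factor)
print(grad_wrt_w(lambda w: -F.softplus(w), -4.6))            # A = -softplus(W): ~-0.3 (similar)
```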
This exponential parameterization has also been discussed in earlier works such as:
- ExpRNN: https://arxiv.org/abs/1901.08428
- S4D: https://arxiv.org/abs/2206.11893