mamba-minimal
Create model.py
import math

import torch
from torch import nn
from torch.nn.utils import weight_norm

class Mamba(nn.Module):
    def __init__(self, d_model, d_state, n_layers, d_inner, dropout=0.1):
        super().__init__()
        # We don't want to learn position embeddings, so we use a simple fixed
        # positional encoding, registered as a buffer so it moves with the module;
        # 64 is the hard-coded maximum sequence length.
        # Note that we divide by sqrt(d_model), a scaling you'll find in other
        # Transformer implementations, where it serves the same purpose as with
        # standard attention.
        # `to(torch.float32)` is there only because this code is intended to be used
        # seamlessly with mixed precision.
        self.register_buffer("pos_enc", torch.arange(0, 64, dtype=torch.float32).to(torch.float32) / math.sqrt(d_model))
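        # e.g. with d_model = 64, `pos_enc` holds the 64 values [0, 1/8, 2/8, ..., 63/8],
        # i.e. position `t` gets the single scalar offset t / sqrt(d_model).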
        layers = []
        for _ in range(n_layers):
            layers.append(MambaLayer(d_model, d_state, d_inner, dropout=dropout))
        self.layers = nn.ModuleList(layers)
        # Final dense layer; 50257 is the GPT-2 BPE vocabulary size. Its input is the
        # concatenation of all layer outputs, hence `n_layers * d_model` features.
        self.fc = nn.Linear(n_layers * d_model, 50257)

    def forward(self, x, state_init=None):
        # The input holds `l` groups of sequences of length `L`, with batch size `b`.
        # `x` has shape `(l, b, L, d_model)`; the first dimension is the `l` one.
        l, b, L, d = x.shape
        if state_init is None:
            # Note: the layers assume `d_state == d_model // 2`, since `state_init` is
            # created with `d_model // 2` channels and later fed to `lin_B2`.
            state_init = torch.zeros(l, b, 1, d // 2, dtype=x.dtype, device=x.device)
        x = x + self.pos_enc[:L, None]
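        # Shape check: `self.pos_enc[:L, None]` has shape `(L, 1)` and broadcasts
        # against `x` of shape `(l, b, L, d)`, so all `d` features at position `t`
        # receive the same offset.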
        states, outs = [], []
        for layer in self.layers:
            x, state = layer(x, state_init)
            states.append(state)
            # Each element of `outs` has shape `(l, b, L, d)`.
            outs.append(x)
        # Returns the logits, of shape `(l, b, L, 50257)`, together with the
        # per-layer states concatenated along dim -2.
        return self.fc(torch.cat(outs, dim=-1)), torch.cat(states, dim=-2)


class MambaLayer(nn.Module):
    def __init__(self, d_model, d_state, d_inner, dropout=0.1):
        super().__init__()
        d_model_half = d_model // 2
        self.lin_A = nn.Linear(d_model, d_model_half)
        self.lin_D = nn.Linear(d_model, d_model_half)
        self.lin_in = nn.Linear(d_model, d_inner)
        self.lin_B1 = nn.Linear(d_inner, d_model_half)
        self.lin_B2 = nn.Linear(d_state, d_model_half)
        self.lin_C = weight_norm(nn.Linear(d_model_half, d_model_half))
        # Project back to `d_model` so that layers can be stacked.
        self.lin_out = nn.Linear(d_model_half, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, state_init):
        # We output both the state AND the transformed sequence (`x`).
        # The `x` shape is expected to be `(l, b, L, d)`.
        # The `state_init` shape is expected to be `(l, b, 1, n)`, with `n == d // 2`.
        l, b, L, d = x.shape
        d_model_half = d // 2
        # We found that a tanh activation works well for A and D.
        A = torch.tanh(self.lin_A(x))
        D = torch.tanh(self.lin_D(x))
        a = self.dropout(self.lin_in(x))
        b1 = self.lin_B1(a)
        b2 = self.dropout(self.lin_B2(state_init))
        B = b1 + b2
        c = self.lin_C(self.dropout(A * B))
        state = D * state_init + c
        # It may look like `state_init` is off by one timestep from A, B, C, and D,
        # but this is not the case, because we start the loop at the 2nd timestep.
        # This is perfectly consistent with the equations of Mamba (see [1],
        # Algorithm 2).
        # Intuitively, we also need `state_{t-1}` rather than `state_t` to compute
        # `x_t`: indeed, `state_{t-1}` is a consequence of `x_{t-1}` and `u_{t-1}`.
        # Using `state_t` instead would be equivalent to having `δ_t = 1` rather than
        # `δ_t = 0`, the value implied by the authors' "zero-input" assumption
        # (see Equation (7) in [1]).
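        # For reference, the discretized SSM recurrence in [1] reads (roughly):
        #     h_t = A_bar * h_{t-1} + B_bar * x_t,    y_t = C * h_t,
        # so the state entering step t is the one produced at step t - 1, which is
        # why `state_init` (the previous state) appears in the update above.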
        x = self.lin_out(A * B + c)
        # We obtain a new state and a new output sequence `x`.
        return x, state
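

# A minimal smoke test: a sketch with made-up hyperparameters (the values of d_model,
# d_state, n_layers, d_inner, l, b, L below are illustrative, not prescribed by this
# repo). It only checks that tensor shapes flow through `Mamba.forward`; note that
# `d_state == d_model // 2` and `L <= 64` are required by the code above.
if __name__ == "__main__":
    d_model, d_state, n_layers, d_inner = 32, 16, 2, 64
    model = Mamba(d_model, d_state, n_layers, d_inner).eval()
    l, b, L = 1, 2, 8
    x = torch.randn(l, b, L, d_model)
    with torch.no_grad():
        logits, states = model(x)
    print(logits.shape)  # torch.Size([1, 2, 8, 50257])
    print(states.shape)  # torch.Size([1, 2, 16, 16]) -- per-layer states along dim -2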