Batched generation with masking/padding
The instructions in the README on running lm-evaluation-harness set batch size > 1, and I would like to try batched generation in a standalone script.
Per this previous thread (https://github.com/state-spaces/mamba/issues/49#issuecomment-1850980748) it seems that standard attention masking and padding tokens are not supported yet, which should also mean that batched generation with differently sized prompts is not currently possible. So how is lm-evaluation-harness able to handle batch size > 1?
The zero-shot evals only require evaluating likelihood (to pick among multiple choices) and not generation.
I don't think the current generation code supports batched generation of different lengths.
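A toy illustration of why likelihood-only evaluation can be batched, assuming the harness right-pads the shorter sequences (this thread does not confirm how it pads). In a causal model, positions to the right cannot influence earlier outputs, so right padding leaves the scores at the real positions untouched. The scalar recurrence `causal_scan` and the `PAD` value below are hypothetical stand-ins, not Mamba itself:

```python
def causal_scan(xs, a=0.9, b=0.5):
    """Toy causal recurrence h_t = a * h_{t-1} + b * x_t; returns every h_t."""
    h, hs = 0.0, []
    for x in xs:
        h = a * h + b * x
        hs.append(h)
    return hs

PAD = 7.0  # arbitrary nonzero stand-in for a pad token's embedding (hypothetical)

short = [1.0, -2.0, 3.0]
batch_len = 5  # length of the longest sequence in the batch
n_pad = batch_len - len(short)

unpadded_out = causal_scan(short)

# Right padding: the real positions never see the pad tokens.
right_out = causal_scan(short + [PAD] * n_pad)[:len(short)]
print("right-padded matches:", right_out == unpadded_out)

# Left padding with an unmasked pad token: the state is contaminated
# before the first real token arrives.
left_out = causal_scan([PAD] * n_pad + short)[-len(short):]
print("left-padded matches:", left_out == unpadded_out)
```

Generation is different because the next token is read from the last position of every row, which pushes the padding to the left, where it has to be masked.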
That makes sense, thanks!
@tridao do you think it would be feasible to implement masking by setting padded timesteps of the discretized A and B matrices to identity operators (i.e. all 1's for A and all 0's for B)? I tried implementing this in the naive selective_scan_ref and it seems to work:
from einops import rearrange, repeat
import torch
import torch.nn.functional as F


def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, mask=None, delta_softplus=False,
                       return_last_state=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32
+   mask: (B L)
    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)
    """
    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    if A.is_complex():
        if is_variable_B:
            B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
        if is_variable_C:
            C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
    else:
        B = B.float()
        C = C.float()
    x = A.new_zeros((batch, dim, dstate))
    ys = []
    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
    else:
        if B.dim() == 3:
            deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
        else:
            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
            deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
    last_state = None
+   if mask is not None:
+       mask = mask[:, None, :, None].expand(-1, dim, -1, dstate) == 0
+       deltaA[mask] = 1
+       deltaB_u[mask] = 0
    for i in range(u.shape[2]):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum('bdn,dn->bd', x, C)
        else:
            if C.dim() == 3:
                y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
            else:
                y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
        if i == u.shape[2] - 1:
            last_state = x
        if y.is_complex():
            y = y.real * 2
        ys.append(y)
    y = torch.stack(ys, dim=2)  # (batch dim L)
    out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, last_state)


batch = 1
dim = 1536
L = 10
N = 16
u = torch.randn((batch, dim, L))
delta = torch.randn((batch, dim, L))
A = torch.randn((dim, N))
B = torch.randn((batch, N, L))
C = torch.randn((batch, N, L))
D = torch.randn(dim)
z = torch.randn((batch, dim, L))
delta_bias = torch.randn(dim)
mask = torch.tensor([[0] + [1] * (L - 1)])
out = selective_scan_ref(
    u,
    delta,
    A,
    B,
    C,
    D,
    z,
    delta_bias,
    delta_softplus=True,
)
out_masked = selective_scan_ref(
    u,
    delta,
    A,
    B,
    C,
    D,
    z,
    delta_bias,
    mask,
    delta_softplus=True,
)
out_true = selective_scan_ref(
    u[..., 1:],
    delta[..., 1:],
    A,
    B[..., 1:],
    C[..., 1:],
    D,
    z[..., 1:],
    delta_bias,
    mask[..., 1:],
    delta_softplus=True,
)
print("Should be False:", torch.allclose(out[:, :, 1:], out_true))
print("Should be True:", torch.allclose(out_masked[:, :, 1:], out_true))
But I'm not sure if there's a simple way to do this in the CUDA kernels.
Yeah, that should work in principle. It might be easier to instead right-align (left-pad) all the prompts in your batch, and make sure that each layer zeros out its output in the padded regions (e.g. by passing in a mask as you did). Then you don't have to touch the internals of the SSM.
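A minimal sketch of the left-padding step suggested above, assuming prompts arrive as plain lists of token ids. The helper name `left_pad` and the choice of `pad_id=0` are hypothetical; in practice the two returned lists would be wrapped in `torch.tensor` before being passed to the model.

```python
def left_pad(prompts, pad_id=0):
    """Right-align token-id lists to a common length.

    Returns (input_ids, mask), where mask is 1 on real tokens and 0 on padding.
    """
    max_len = max(len(p) for p in prompts)
    input_ids, mask = [], []
    for p in prompts:
        n_pad = max_len - len(p)
        input_ids.append([pad_id] * n_pad + list(p))
        mask.append([0] * n_pad + [1] * len(p))
    return input_ids, mask


prompts = [[11, 12, 13, 14], [21, 22], [31]]
input_ids, mask = left_pad(prompts)
for ids, m in zip(input_ids, mask):
    print(ids, m)
```

With this layout the last column holds the final real token of every prompt, so next-token logits can be read from the same position for the whole batch.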
Thanks, that makes sense. I didn't realize that deltaB_u was a linear transformation of x. I guess this approach doesn't technically handle internal pad tokens correctly, but it works for left-padded generation.
Yeah, I think this relies on the idea that the only thing you care about is that the recurrent state is 0 in the padded region, so it's not affecting the relevant parts. Similar to your idea of setting Abar and Bbar appropriately to ensure that the hidden state gets transmitted through. In this case if the input $x=0$ and state $h=0$, then the state should remain 0.
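The argument can be checked on a scalar version of the recurrence, $h_t = \bar{A}_t h_{t-1} + \bar{B}_t x_t$. This is only a sketch: the function `scan` and all coefficients are made up for illustration, and the real model applies this per channel and state dimension.

```python
def scan(a, b, x):
    """Scalar recurrence h_t = a_t * h_{t-1} + b_t * x_t, starting from h = 0."""
    h, hs = 0.0, []
    for a_t, b_t, x_t in zip(a, b, x):
        h = a_t * h + b_t * x_t
        hs.append(h)
    return hs


# Reference run: three real tokens, no padding.
a = [0.9, 0.8, 0.7]
b = [0.5, 0.4, 0.3]
x = [1.0, 2.0, 3.0]
ref = scan(a, b, x)

# Left padding: two pad steps with arbitrary a_t, b_t but the input masked to 0.
left = scan([0.6, 0.6] + a, [0.2, 0.2] + b, [0.0, 0.0] + x)
assert left[:2] == [0.0, 0.0]  # the state never leaves zero in the padded region
assert left[2:] == ref         # real positions match the unpadded run

# Internal padding: zeroing the input is not enough, the state still decays by a_t.
zeroed = scan([0.9, 0.6, 0.8, 0.7], [0.5, 0.2, 0.4, 0.3], [1.0, 0.0, 2.0, 3.0])
assert zeroed[3] != ref[2]

# Identity step (a_t = 1, b_t = 0) carries the state through the pad unchanged,
# even though the pad input itself (5.0) is not zeroed.
identity = scan([0.9, 1.0, 0.8, 0.7], [0.5, 0.0, 0.4, 0.3], [1.0, 5.0, 2.0, 3.0])
assert [identity[0], identity[2], identity[3]] == ref
```

Zeroing the input is sufficient only while the state is still zero, which is why it covers left padding but not pad tokens in the middle of a sequence. The identity-step variant covers both.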
I tested this out in the slow path of Mamba.forward by masking twice (once before the causal conv1d and once before the selective scan):
class Mamba(nn.Module):
    ...
    def forward(self, hidden_states, mask=None, inference_params=None):
        ....
        if self.use_fast_path and inference_params is None:  # Doesn't support outputting the states
            ...
        else:
            x, z = xz.chunk(2, dim=1)
+           if mask is not None:
+               x = x * mask.unsqueeze(1)
            # Compute short convolution
            if conv_state is not None:
                conv_state.copy_(x[:, :, -self.d_conv :])  # Update state (B D W)
            if causal_conv1d_fn is None:
                x = self.act(self.conv1d(x)[..., :seqlen])
            else:
                assert self.activation in ["silu", "swish"]
                x = causal_conv1d_fn(
                    x,
                    rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    self.conv1d.bias,
                    self.activation,
                )
+           if mask is not None:
+               x = x * mask.unsqueeze(1)
            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
            ...
Testing with this script:
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
model = MambaLMHeadModel.from_pretrained('/data/norman_mu/models/mamba-130m').to('cuda')
input_ids = torch.randint(1, 1000, (1, 1024)).to('cuda')
input_ids_padded = torch.cat([torch.zeros_like(input_ids[:, [0]]), input_ids], dim=1)
attention_mask = torch.cat([torch.zeros_like(input_ids[:, [0]]), torch.ones_like(input_ids)], dim=1)
out = model(input_ids_padded).logits.detach().cpu()
out_padded = model(input_ids_padded, attention_mask).logits.detach().cpu()
out_true = model(input_ids).logits.detach().cpu()
print("max L2 error:", (out_true - out[:, 1:]).norm(dim=-1).max())
print("max L2 errors (padded):", (out_true - out_padded[:, 1:]).norm(dim=-1).max())
This prints:
max L2 error: tensor(24580.3848)
max L2 errors (padded): tensor(0.5131)
which isn't perfect but also doesn't seem too bad for 50k dim logits. I'm guessing this is due to the causal_conv1d leaking information from the pad token in index 0. Does causal_conv1d not use zero padding?
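A small single-channel sketch of what a zero-padded causal convolution does with a left pad token. It assumes the convolution pads with zeros on the left, as the slow path's `self.conv1d(x)[..., :seqlen]` would if the layer is built with `padding=d_conv - 1` and PyTorch's default zero padding mode. It says nothing about the fused `causal_conv1d_fn` kernel. The function `causal_conv` and the weights are made up:

```python
def causal_conv(x, w, bias=0.0):
    """Single-channel causal conv with implicit zero padding on the left."""
    k = len(w)
    padded = [0.0] * (k - 1) + list(x)
    return [bias + sum(w[j] * padded[t + j] for j in range(k)) for t in range(len(x))]


w = [0.1, -0.2, 0.3, 0.4]       # kernel width 4, made-up weights
x = [1.0, 2.0, -1.0, 0.5, 3.0]  # real tokens
ref = causal_conv(x, w, bias=0.05)

# Pad token zeroed by the mask before the conv: same as one more zero of padding.
masked = causal_conv([0.0] + x, w, bias=0.05)
print("masked pad matches:", masked[1:] == ref)

# Unmasked pad token: it leaks into the next k - 1 outputs and no further.
leaky = causal_conv([9.0] + x, w, bias=0.05)
print("positions changed:", [p != q for p, q in zip(leaky[1:], ref)])

# The output at the pad position itself is the bias rather than zero,
# which is what the second mask (after the conv) removes.
print("output at pad position:", masked[0])
```

Under that zero-padding assumption, a pad token that is masked before the convolution should not leak into the real positions, so the remaining error may come from somewhere else.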