
Batched generation with masking/padding

Open normster opened this issue 1 year ago • 7 comments

The instructions in the README on running lm-evaluation-harness set batch size > 1, and I would like to try batched generation in a standalone script.

Per this previous thread (https://github.com/state-spaces/mamba/issues/49#issuecomment-1850980748), it seems that standard attention masking/padding tokens are not supported yet. That should also mean batched generation with differently sized prompts is not currently possible, so how is lm-evaluation-harness able to handle batch size > 1?

normster · Dec 19 '23 08:12

The zero-shot evals only require evaluating likelihood (to pick among multiple choices) and not generation.

I don't think the current generation code supports batched generation of different lengths.
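Why batch size > 1 is fine for likelihood-only evals can be sketched as follows, assuming right padding. In a causal model the logits at real positions never see the pad tokens appended after them, so the pads only need to be dropped from the sum. The helper name `masked_sequence_logprob` is hypothetical and is not the lm-evaluation-harness API.

```python
import torch
import torch.nn.functional as F


def masked_sequence_logprob(logits, input_ids, mask):
    """Sum of next-token log-probabilities over real (unpadded) tokens.

    logits: (B, L, V) from a causal LM; input_ids, mask: (B, L) with mask = 1
    for real tokens and 0 for pads. Padding is assumed to be on the right, so
    logits at real positions do not depend on the pads.
    """
    # Position t predicts token t + 1, so drop the last logit and the first id.
    logprobs = F.log_softmax(logits[:, :-1].float(), dim=-1)
    targets = input_ids[:, 1:]
    token_lp = logprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)  # (B, L-1)
    # Zero out the contributions of padded target positions.
    return (token_lp * mask[:, 1:]).sum(dim=-1)
```

A harness typically scores only the continuation tokens; setting the context positions of `mask` to 0 as well would restrict the sum to those.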

tridao · Dec 19 '23 08:12

That makes sense, thanks!

normster · Dec 19 '23 18:12

@tridao do you think it would be feasible to implement masking by setting the discretized A and B at padded timesteps so that the state update becomes the identity (i.e. all 1's for A and all 0's for B)? I tried implementing this in the naive selective_scan_ref, and it seems to work:

from einops import rearrange, repeat
import torch
import torch.nn.functional as F

def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, mask=None, delta_softplus=False,
                      return_last_state=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32
    mask: (B L), 1 for real tokens and 0 for padding (added)

    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)
    """
    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    if A.is_complex():
        if is_variable_B:
            B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
        if is_variable_C:
            C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
    else:
        B = B.float()
        C = C.float()
    x = A.new_zeros((batch, dim, dstate))
    ys = []
    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
    else:
        if B.dim() == 3:
            deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
        else:
            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
            deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
    last_state = None
    # Added: make padded timesteps an identity update (deltaA = 1, deltaB_u = 0)
    if mask is not None:
        mask = mask[:, None, :, None].expand(-1, dim, -1, dstate) == 0
        deltaA[mask] = 1
        deltaB_u[mask] = 0
    for i in range(u.shape[2]):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum('bdn,dn->bd', x, C)
        else:
            if C.dim() == 3:
                y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
            else:
                y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
        if i == u.shape[2] - 1:
            last_state = x
        if y.is_complex():
            y = y.real * 2
        ys.append(y)
    y = torch.stack(ys, dim=2) # (batch dim L)
    out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, last_state)

batch = 1
dim = 1536
L = 10
N = 16

u = torch.randn((batch, dim, L))
delta = torch.randn((batch, dim, L))
A = torch.randn((dim, N))
B = torch.randn((batch, N, L))
C = torch.randn((batch, N, L))
D = torch.randn(dim)
z = torch.randn((batch, dim, L))
delta_bias = torch.randn(dim)
mask = torch.tensor([[0] + [1] * (L - 1)])

out = selective_scan_ref(
    u,
    delta,
    A,
    B,
    C,
    D,
    z,
    delta_bias,
    delta_softplus=True,
)

out_masked = selective_scan_ref(
    u,
    delta,
    A,
    B,
    C,
    D,
    z,
    delta_bias,
    mask,
    delta_softplus=True,
)

out_true = selective_scan_ref(
    u[..., 1:],
    delta[..., 1:],
    A,
    B[..., 1:],
    C[..., 1:],
    D,
    z[..., 1:],
    delta_bias,
    mask[..., 1:],
    delta_softplus=True,
)

print("Should be False:", torch.allclose(out[:, :, 1:], out_true))
print("Should be True:", torch.allclose(out_masked[:, :, 1:], out_true))

But I'm not sure if there's a simple way to do this in the CUDA kernels.

normster · Dec 19 '23 22:12

Yeah, that should work in principle. It might be easier to instead right-align (left-pad) all the prompts in your batch, and make sure that each layer zeros out its output in the padded regions (e.g. by passing in a mask as you did). Then you don't have to touch the internals of the SSM.
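The left-padding setup described above can be sketched as follows. `left_pad` is a hypothetical helper, and `pad_id=0` is an assumption rather than a Mamba convention. The helper right-aligns the prompts and builds the mask that each layer would use to zero its output in the padded region.

```python
import torch


def left_pad(prompts, pad_id=0):
    """Right-align token-id lists by left padding (hypothetical helper).

    Returns input_ids and mask, both (B, L_max); mask is 1 for real tokens
    and 0 for pads. pad_id=0 is an assumption, not a Mamba convention.
    """
    max_len = max(len(p) for p in prompts)
    input_ids = torch.full((len(prompts), max_len), pad_id, dtype=torch.long)
    mask = torch.zeros((len(prompts), max_len), dtype=torch.long)
    for i, p in enumerate(prompts):
        start = max_len - len(p)  # first real position in row i
        input_ids[i, start:] = torch.tensor(p, dtype=torch.long)
        mask[i, start:] = 1
    return input_ids, mask


input_ids, mask = left_pad([[5, 6, 7], [8]])
# input_ids: [[5, 6, 7], [0, 0, 8]]   mask: [[1, 1, 1], [0, 0, 1]]
```

Because every row ends at the same column, the next generated token lands at the same index for the whole batch. A layer would then zero its padded outputs with something like `hidden * mask.unsqueeze(-1)` for `(B, L, D)` activations.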

albertfgu · Dec 19 '23 22:12

Thanks, that makes sense. I didn't realize that deltaB_u has the input as a multiplicative factor (u in selective_scan_ref, x in Mamba.forward), so it vanishes wherever the input is zeroed. I guess this approach doesn't technically handle internal pad tokens correctly, but it works for left-padded generation.
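A toy single-channel scan can make this trade-off concrete. The function `scan` and all values here are hypothetical, and `a` and `bx` stand in for one channel of `deltaA` and `deltaB_u`. The example shows three cases. Zero input on left-padded steps keeps the state at exactly 0. An internal zero-input step still applies its decay. The identity update from the masked `selective_scan_ref` passes the state through unchanged.

```python
import torch


def scan(a, bx):
    """Toy linear recurrence h_t = a_t * h_{t-1} + bx_t starting from h = 0.

    a, bx: (L,) tensors standing in for one channel of deltaA and deltaB_u.
    Returns the stacked states, shape (L,).
    """
    h = torch.zeros(())
    hs = []
    for a_t, bx_t in zip(a, bx):
        h = a_t * h + bx_t
        hs.append(h)
    return torch.stack(hs)


torch.manual_seed(0)
L = 6
a = torch.rand(L) * 0.9 + 0.05  # decays in (0, 1), like exp(delta * A) with A < 0
bx = torch.randn(L)
ref = scan(a, bx)

# Left padding: two pad steps with zero input keep the state exactly 0,
# so the real positions reproduce the unpadded states.
left = scan(torch.cat([torch.rand(2), a]), torch.cat([torch.zeros(2), bx]))[2:]

# Internal pad with zero input: the state is still multiplied by a decay (0.5),
# so every later state drifts away from the reference.
a_mid = torch.cat([a[:3], torch.tensor([0.5]), a[3:]])
bx_mid = torch.cat([bx[:3], torch.zeros(1), bx[3:]])
h_mid = scan(a_mid, bx_mid)
mid = torch.cat([h_mid[:3], h_mid[4:]])

# Identity update at the internal pad (a = 1, bx = 0) carries the state through.
a_id = a_mid.clone()
a_id[3] = 1.0
h_id = scan(a_id, bx_mid)
ident = torch.cat([h_id[:3], h_id[4:]])
```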

normster · Dec 19 '23 23:12

Yeah, I think this relies on the only requirement being that the recurrent state is 0 throughout the padded region, so it cannot affect the real tokens. This is similar to your idea of setting Abar and Bbar so that the hidden state is passed through unchanged. Here, if the input $x=0$ and the state $h=0$, then the state stays 0.
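In the discretized SSM recurrence, the two strategies differ only in which terms are forced to vanish at a padded step $t$. Here $\bar{A}_t$ corresponds to `deltaA` and $\bar{B}_t x_t$ to `deltaB_u` in selective_scan_ref. Because the scan starts from $h = 0$, the left-padding case holds by induction over all leading pad steps.

$$
\begin{aligned}
h_t &= \bar{A}_t\, h_{t-1} + \bar{B}_t\, x_t, \qquad y_t = C_t\, h_t + D\, x_t \\
\text{left padding } (h_{t-1} = 0,\ x_t = 0) &\;\Rightarrow\; h_t = \bar{A}_t \cdot 0 + \bar{B}_t \cdot 0 = 0 \\
\text{identity update } (\bar{A}_t = 1,\ \bar{B}_t = 0) &\;\Rightarrow\; h_t = h_{t-1}
\end{aligned}
$$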

albertfgu · Dec 20 '23 00:12

I tested this out in the slow path of Mamba.forward by masking twice (once before the causal conv1d and once before the selective scan):

class Mamba(nn.Module):
    ...
    def forward(self, hidden_states, mask=None, inference_params=None):
        ...
        if self.use_fast_path and inference_params is None:  # Doesn't support outputting the states
            ...
        else:
            x, z = xz.chunk(2, dim=1)

            # Added: zero out padded positions before the causal conv1d
            if mask is not None:
                x = x * mask.unsqueeze(1)

            # Compute short convolution
            if conv_state is not None:
                conv_state.copy_(x[:, :, -self.d_conv :])  # Update state (B D W)
            if causal_conv1d_fn is None:
                x = self.act(self.conv1d(x)[..., :seqlen])
            else:
                assert self.activation in ["silu", "swish"]
                x = causal_conv1d_fn(
                    x,
                    rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    self.conv1d.bias,
                    self.activation,
                )

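            # Added: zero out padded positions again before the selective scan,
            # since the conv output there is generally nonzero (e.g. from the bias)
            if mask is not None:
                x = x * mask.unsqueeze(1)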

            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
            ...
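A small standalone check can isolate the double-masking idea for the `nn.Conv1d` fallback. Sizes and values are arbitrary. `causal_conv` is a hypothetical helper that mirrors `self.conv1d(x)[..., :seqlen]` without the activation, assuming a depthwise conv with `padding=d_conv - 1`. The check suggests that masking before the conv leaves the real positions exactly as in the unpadded run. Without that mask, a nonzero pad leaks into the next `d_conv - 1` positions. The conv output at the masked pad itself equals the bias, which is why a second mask before the scan is needed. This does not exercise the CUDA `causal_conv1d_fn` path.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_inner, d_conv, L = 8, 4, 6

# Depthwise conv zero-padded by d_conv - 1 on both sides; keeping the first
# seqlen outputs makes it causal (assumed to mirror the slow-path conv).
conv1d = nn.Conv1d(d_inner, d_inner, kernel_size=d_conv, groups=d_inner, padding=d_conv - 1)


def causal_conv(x):
    return conv1d(x)[..., : x.shape[-1]]


x = torch.randn(1, d_inner, L)                # unpadded activations
pad = torch.full((1, d_inner, 1), 5.0)        # nonzero activations at a pad token
x_padded = torch.cat([pad, x], dim=-1)        # one left pad step
mask = torch.tensor([[0] + [1] * L])          # 0 marks the pad

with torch.no_grad():
    ref = causal_conv(x)
    leaky = causal_conv(x_padded)[..., 1:]                      # no mask
    masked_out = causal_conv(x_padded * mask.unsqueeze(1))      # mask before conv
    masked = masked_out[..., 1:]
    at_pad = masked_out[..., 0]                                 # equals the bias
```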

Testing with this script:

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained('/data/norman_mu/models/mamba-130m').to('cuda')
input_ids = torch.randint(1, 1000, (1, 1024)).to('cuda')
input_ids_padded = torch.cat([torch.zeros_like(input_ids[:, [0]]), input_ids], dim=1)
attention_mask = torch.cat([torch.zeros_like(input_ids[:, [0]]), torch.ones_like(input_ids)], dim=1)

out = model(input_ids_padded).logits.detach().cpu()
out_padded = model(input_ids_padded, attention_mask).logits.detach().cpu()
out_true = model(input_ids).logits.detach().cpu()

print("max L2 error:", (out_true - out[:, 1:]).norm(dim=-1).max())
print("max L2 errors (padded):", (out_true - out_padded[:, 1:]).norm(dim=-1).max())

This prints:

max L2 error: tensor(24580.3848)
max L2 errors (padded): tensor(0.5131)

This isn't perfect, but it also doesn't seem too bad for ~50k-dimensional logits. I'm guessing the remaining error comes from the causal_conv1d leaking information from the pad token at index 0. Does causal_conv1d not use zero padding?

normster · Dec 20 '23 02:12