
[Bug] XL-recurrence with AlibiPositionalBias and mems not working correctly

Open pfeatherstone opened this issue 1 year ago • 13 comments

This is a similar issue to https://github.com/lucidrains/x-transformers/issues/223 but with Alibi.

So, I am trying to do XL-recurrence with:

  • AlibiPositionalBias
  • attn_num_mem_kv > 0
  • mems, mem_masks and return_mems

The repro is:

import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

lm = ContinuousTransformerWrapper(
    dim_in              = 2,
    dim_out             = 36,
    max_seq_len         = 0,
    max_mem_len         = 100,
    attn_layers = Decoder (
        dim             = 512,
        depth           = 4,
        heads           = 4,

        disable_abs_pos_emb = True,
        alibi_pos_bias  = True,
        alibi_num_heads = 2,

        attn_flash      = True,
        attn_num_mem_kv = 20
    )
)

B, M, D, depth = 1, lm.max_mem_len, lm.attn_layers.dim, lm.attn_layers.depth

x           = torch.randn(B, 1024, 2)
length      = torch.randint(100, x.shape[1], size=(x.shape[0],))
mask        = torch.arange(x.shape[1]).unsqueeze(0) < length.unsqueeze(-1)
mems        = [torch.randn(x.shape[0], M, D) for _ in range(depth)]
mem_masks   = [torch.zeros(x.shape[0], M).bool() for _ in range(depth)]

# mem_masks are all False, so the memories are entirely masked out and both
# calls should produce identical outputs and returned mems
out1, mems1 = lm(x, mask=mask, return_mems=True)
out2, mems2 = lm(x, mask=mask, mems=mems, mem_masks=mem_masks, return_mems=True)
torch.testing.assert_close(out1, out2)
for m1, m2 in zip(mems1, mems2):
    torch.testing.assert_close(m1, m2)

I imagine the issue and fix are similar to the RoPE case.

pfeatherstone avatar Feb 22 '24 08:02 pfeatherstone
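
For context, a rough sketch of what the ALiBi bias has to look like once memories are prepended to the keys: the relative distances must include the memory offset, and the learned attn_num_mem_kv slots get zero bias. This is illustrative only, not x-transformers' implementation; it assumes all heads are biased and the standard power-of-two slope schedule.

import torch

def alibi_slopes(num_heads):
    # geometric slope schedule from the ALiBi paper; assumes num_heads is a power of 2
    start = 2 ** (-8 / num_heads)
    return torch.tensor([start ** (i + 1) for i in range(num_heads)])

def alibi_bias_with_mems(seq_len, mem_len, num_mem_kv, num_heads):
    # queries sit at positions [mem_len, mem_len + seq_len),
    # keys cover the memories plus the current segment: [0, mem_len + seq_len)
    q_pos = torch.arange(mem_len, mem_len + seq_len)
    k_pos = torch.arange(mem_len + seq_len)
    dist  = (q_pos[:, None] - k_pos[None, :]).clamp(min = 0)  # causal relative distance
    bias  = -alibi_slopes(num_heads)[:, None, None] * dist    # (heads, seq, mem + seq)
    # the learned memory key / values carry no position, so pad the key dim with zero bias
    return torch.nn.functional.pad(bias, (num_mem_kv, 0), value = 0.)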

@pfeatherstone yea i could make it work

however, i think alibi is really bad and probably should be removed. i would never use it for any serious model

lucidrains avatar Feb 23 '24 15:02 lucidrains

@pfeatherstone maybe i'll just make it all work as a personal challenge. i can also fix dynamic pos bias in the presence of memory as well

lucidrains avatar Feb 23 '24 15:02 lucidrains

Basically I found rotary positional embedding doesn't length-extrapolate at all. Like, it's really bad. I'm doing some tests with XPOS and, though I haven't got a complete model yet (still training), it looks a bit better. However, I'm nervous about limiting the context length.

pfeatherstone avatar Feb 23 '24 16:02 pfeatherstone

@pfeatherstone yea i could make it work

however, i think alibi is really bad and probably should be removed. i would never use it for any serious model

Is this based on personal tests? According to the paper, it's the best thing since sliced bread.

pfeatherstone avatar Feb 23 '24 16:02 pfeatherstone

Basically I found rotary positional embedding doesn't length-extrapolate at all. Like, it's really bad. I'm doing some tests with XPOS and, though I haven't got a complete model yet (still training), it looks a bit better. However, I'm nervous about limiting the context length.

that's well known, i think i even mention it in the readme. however, there's a lot of research going into fine tuning trained rotary models to longer context, so it is not a big deal

lucidrains avatar Feb 23 '24 16:02 lucidrains

i wouldn't use xpos either.. it suffers from the same issues as alibi. i really should start removing features i no longer believe in

lucidrains avatar Feb 23 '24 16:02 lucidrains

@pfeatherstone yea i could make it work. however, i think alibi is really bad and probably should be removed. i would never use it for any serious model

Is this based on personal tests? According to the paper, it's the best thing since sliced bread.

which paper?

lucidrains avatar Feb 23 '24 16:02 lucidrains

https://ofir.io/train_short_test_long.pdf, the one you reference in your readme. I have to admit I haven't read it in great detail, but they suggest ALiBi is great.

pfeatherstone avatar Feb 23 '24 16:02 pfeatherstone

Basically I need a positional embedding that length-extrapolates well and works with memories and flash attention. Do you have any suggestions?

pfeatherstone avatar Feb 23 '24 16:02 pfeatherstone

@pfeatherstone that's from the author of alibi. of course they would say it is great

lucidrains avatar Feb 23 '24 16:02 lucidrains

Basically I need a positional embedding that length-extrapolates well and works with memories and flash attention. Do you have any suggestions?

these days, i would stick with rotary, given the amount of research now going into it

curriculum learn to longer sequence lengths while tuning the rotary theta value (and whatever new tricks recent papers have discovered)

lucidrains avatar Feb 23 '24 16:02 lucidrains
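
To make the theta knob concrete, here is a minimal sketch of how the rotary base enters the frequency computation (the standard formulation, not a snippet of x-transformers itself). Raising theta above the conventional 10000 stretches the wavelengths, which is the value the long-context fine-tuning recipes adjust.

import torch

def rotary_freqs(dim, seq_len, theta = 10000.0):
    # per-pair inverse frequencies; theta is the tunable base
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    pos = torch.arange(seq_len).float()
    freqs = torch.einsum('i,j->ij', pos, inv_freq)  # (seq_len, dim // 2)
    return freqs.cos(), freqs.sin()                 # used to rotate q and k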

What do you mean by curriculum learn to longer sequence lengths? Sorry if my questions are dumb.

pfeatherstone avatar Feb 23 '24 16:02 pfeatherstone

@pfeatherstone ah, curriculum learning is just a fancy way of saying making training increasingly harder over time, like how you design a curriculum for a student. so start with a small sequence length and slowly increase to your desired length

lucidrains avatar Feb 23 '24 16:02 lucidrains
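
Purely as an illustration (the schedule, make_batch and train_step below are hypothetical placeholders, not part of x-transformers), a sequence-length curriculum just means scheduling the crop length during training:

# hypothetical schedule of (number of steps, sequence length)
schedule = [(10_000, 512), (20_000, 1024), (20_000, 2048), (50_000, 4096)]

for num_steps, seq_len in schedule:
    for _ in range(num_steps):
        batch = make_batch(seq_len)  # crop / sample training sequences at this length
        train_step(batch)            # ordinary optimizer step on the model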

@pfeatherstone do you want to see if 1.28.0 fixes the issue?

lucidrains avatar Apr 29 '24 01:04 lucidrains

@pfeatherstone do you want to see if 1.28.0 fixes the issue?

I think that update broke something; I got this error while training the transformer in MeshGPT.

Was it the update, or did the attention args start to kick in? attn_kwargs: dict = dict( ff_glu = True, num_mem_kv = 4 ), https://github.com/lucidrains/meshgpt-pytorch/blob/main/meshgpt_pytorch/meshgpt_pytorch.py#L1011C7-L1014C11

File /opt/conda/lib/python3.10/site-packages/x_transformers/x_transformers.py:950, in Attention.forward(self, x, context, mask, context_mask, attn_mask, rel_pos, rotary_pos_emb, prev_attn, mem, mem_mask, return_intermediates, cache)
    947 # append with no bias for memory key / values
    949 if has_mem_kv:
--> 950     attn_bias = pad_at_dim(attn_bias, (self.num_mem_kv, 0), value = 0.)
    952 # attention is all we need
    954 out, intermediates = self.attend(
    955     q, k, v,
    956     mask = final_attn_mask,
    957     attn_bias = attn_bias,
    958     prev_attn = prev_attn
    959 )

File /opt/conda/lib/python3.10/site-packages/x_transformers/x_transformers.py:97, in pad_at_dim(t, pad, dim, value)
     95 dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
     96 zeros = ((0, 0) * dims_from_right)
---> 97 return F.pad(t, (*zeros, *pad), value = value)

TypeError: pad(): argument 'input' (position 1) must be Tensor, not NoneType

MarcusLoppe avatar Apr 29 '24 03:04 MarcusLoppe
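
From the traceback, the failure mode is that attn_bias is still None when num_mem_kv > 0 but no relative positional bias is configured, so the padding helper receives None instead of a tensor. A minimal self-contained demonstration of the pattern, with pad_at_dim adapted from the traceback and a None guard that avoids the TypeError (illustrative only, not necessarily the exact fix that landed):

import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # helper adapted from the traceback above: pad tensor t at the given dim
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

num_mem_kv = 4
attn_bias = None  # no ALiBi / relative positional bias configured

# guarding against a missing bias skips the padding and avoids the TypeError
if attn_bias is not None:
    attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.)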

@MarcusLoppe oh hey Marcus! fancy seeing you here 😄

should be fixed!

lucidrains avatar Apr 29 '24 03:04 lucidrains

@MarcusLoppe oh hey Marcus! fancy seeing you here 😄

should be fixed!

I get around 😄 I'm trying to find some attention that can deal with long context lengths effectively; maybe I'll try that ring attention you have 😄

Awesome, I'll confirm later on if it's fixed.

Edit: All good!

MarcusLoppe avatar Apr 29 '24 03:04 MarcusLoppe