
Usage with x-transformers

TKassis opened this issue 1 year ago · 20 comments

PEER looks like an interesting approach, and thanks for implementing it so cleanly! I do have a quick question, though, about recommended usage with x-transformers. Would something like this be a good way of using it?

import torch
from PEER_pytorch import PEER
from x_transformers import ContinuousTransformerWrapper, Encoder

peer = PEER(
    dim = 512,
    heads = 8,                   
    num_experts = 1_000_000,     
    num_experts_per_head = 16,   
    dim_key = 128,
    pre_rmsnorm = True
).cuda()


pre_peer = ContinuousTransformerWrapper(
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

post_peer = ContinuousTransformerWrapper(
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)


x = torch.randn(2, 1024, 512).cuda()

out = pre_peer(x)
out = peer(out) + out
out = post_peer(out)

TKassis avatar Jul 17 '24 13:07 TKassis

oh hey Tim! good to hear from you again

yea, so there's no way to easily slot this into x-transformers atm, but if you'd like, we could discuss how to build some feature to do so on discord

lucidrains avatar Jul 17 '24 13:07 lucidrains

@TKassis in the original product key memory paper Lample et al., i think optimal placement was in the middle of the network, so you are doing it right

lucidrains avatar Jul 17 '24 13:07 lucidrains

Got it, thanks I'll try it out!

TKassis avatar Jul 17 '24 13:07 TKassis

@TKassis sg, let me know if you see (or don't see) anything!

lucidrains avatar Jul 17 '24 13:07 lucidrains

Unfortunately, I'm running out of memory even with only 2500 experts on a 48 GB A6000 Ada.

        self.pre_peer = ContinuousTransformerWrapper(
            max_seq_len=0,
            attn_layers=Encoder(
                dim=768,
                depth=6,
                heads=12,
                attn_flash=True,
            ),
            scaled_sinu_pos_emb=True,
        )

        self.peer = PEER(
            dim = 768,
            heads = 8,                   # tested up to 32 - (hk = heads * num_experts_per_head (16))
            num_experts = 2500,     # he chose 1 million
            num_experts_per_head = 16,   # he settled on 16, but was 32 in PKM paper
            dim_key = 128,
            pre_rmsnorm = True
        )

        self.post_peer = ContinuousTransformerWrapper(
            max_seq_len=0,
            attn_layers=Encoder(
                dim=768,
                depth=6,
                heads=12,
                attn_flash=True,
            ),
            use_abs_pos_emb=False,
        )

TKassis avatar Jul 17 '24 13:07 TKassis
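For context on where the memory goes: with single-neuron experts (one dim → 1 down vector and one 1 → dim up vector each, as in the PEER paper), the expert weights themselves are tiny at 2500 experts. A rough back-of-the-envelope (the batch size is not given in the thread, so 8 is assumed here) suggests the per-token gathered weights dominate, if a straightforward implementation materializes them:

```python
# rough memory estimate for one PEER layer (fp32), assuming the
# selected expert weights are materialized per token
batch, seq, dim = 8, 512, 768      # batch size 8 is an assumption
heads, k = 8, 16                   # k = num_experts_per_head
num_experts = 2500

# expert weights: each expert is a single hidden neuron (dim -> 1 -> dim)
expert_param_bytes = num_experts * 2 * dim * 4

# gathered weights: every token selects heads * k experts,
# each contributing a down and an up vector of length dim
gathered_bytes = batch * seq * heads * k * 2 * dim * 4

print(f"expert weights: {expert_param_bytes / 1e6:.1f} MB")    # ~15 MB
print(f"gathered per layer: {gathered_bytes / 1e9:.2f} GB")    # ~3.2 GB
```

So even a few thousand experts can exhaust memory once these intermediates are saved for the backward pass across several layers.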

@TKassis want to give this wrapper a try?

lucidrains avatar Jul 17 '24 14:07 lucidrains

Thank you, I gave it a try this morning with the ChunkedPEER wrapper on v0.1.9; unfortunately it is still running out of memory with 2500 experts. The original unsplit model (were I to combine pre_peer and post_peer) works without any issues. I guess this is designed for DeepMind compute resources :-)

TKassis avatar Jul 18 '24 13:07 TKassis
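Since PEER acts on each token independently, one generic way to trade compute for peak memory is to split the sequence into chunks and recompute activations on the backward pass. A minimal sketch of that idea using torch.utils.checkpoint (not the library's ChunkedPEER, whose internals may differ):

```python
import torch
from torch.utils.checkpoint import checkpoint

def chunked_forward(module, x, chunks = 4):
    # split along the sequence dimension; for a token-wise module
    # (no cross-token interaction) this is exact, and checkpointing
    # recomputes each chunk's activations during the backward pass
    outs = [checkpoint(module, chunk, use_reentrant = False) for chunk in x.chunk(chunks, dim = 1)]
    return torch.cat(outs, dim = 1)

# token-wise stand-in for the PEER layer
mlp = torch.nn.Linear(64, 64)
x = torch.randn(2, 512, 64, requires_grad = True)

out = chunked_forward(mlp, x, chunks = 8)
out.sum().backward()   # gradients flow through the checkpointed chunks
```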

@TKassis ah ok, thanks for testing it out!

how long are the sequences you are working with?

lucidrains avatar Jul 18 '24 13:07 lucidrains

512

TKassis avatar Jul 18 '24 13:07 TKassis

@TKassis ok, i'll do some profiling later this weekend, thank you!

lucidrains avatar Jul 18 '24 14:07 lucidrains

I use PEER and PKAttention in the middle layer of a 12-layer transformer:

pk_attn = PKAttention(
    dim = 1536,
    num_key_values = 200 * 200,
    pre_rmsnorm = True
)

peer_mlp = PEER(
    dim = 1536,
    heads = 8,
    num_experts = 200 * 200,
    num_experts_per_head = 16,
    dim_key = 128,
    pre_rmsnorm = True
)

In the forward pass:

x = x + pk_attn(x)
x = x + peer_mlp(x)

The good news is that memory does not run out on the 32 GB V100, and the FLOPs are fine.

The bad news is that the perplexity curves are not smooth or ideal!

The question, then, is whether pk_attn and peer_mlp can be used together?

junphine avatar Aug 22 '24 09:08 junphine

image

junphine avatar Aug 22 '24 11:08 junphine

@junphine thanks for testing it out

could you try this improvisation and see if it is any more stable?

lucidrains avatar Aug 22 '24 13:08 lucidrains

@lucidrains Yes, PEERLora is much more stable, with init:

self.proj_in.weight.normal_(std = dim ** -0.5)
self.proj_out.weight.normal_(std = dim_inner ** -0.5)
self.proj_in_lora_a.weight.normal_(std = dim ** -0.5)
self.proj_in_lora_b.weight.normal_(std = dim_inner ** -0.5)
self.proj_out_lora_a.weight.normal_(std = dim_inner ** -0.5)
self.proj_out_lora_b.weight.normal_(std = dim ** -0.5)

But it will take a longer training run to verify, because I find the values of lora_in_hidden tend to be very large.

junphine avatar Aug 23 '24 05:08 junphine
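The std = dim ** -0.5 initialization above is the usual variance-preserving choice: for unit-variance inputs, each projection's output variance stays near 1 rather than growing with width. A quick check of that property, using a plain nn.Linear as a stand-in (not the actual PEERLora internals):

```python
import torch

torch.manual_seed(0)
dim = 1024

proj = torch.nn.Linear(dim, dim, bias = False)
proj.weight.data.normal_(std = dim ** -0.5)   # same init as above

x = torch.randn(4096, dim)          # unit-variance input
out_std = proj(x).std().item()      # stays close to 1.0
```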

image

junphine avatar Aug 23 '24 10:08 junphine

@junphine nice, i added in some better init as well, thanks for reporting these results!

lucidrains avatar Aug 23 '24 14:08 lucidrains

@lucidrains Unfortunately, the PEERLora layer didn't seem to be beneficial: whether I removed it (replaced it with an MLP) or added it, the ppl curve didn't change at all. The two curves coincide perfectly.

junphine avatar Aug 26 '24 09:08 junphine

@junphine ah, that is unfortunate

how about the original formulation, once stabilized of course?

lucidrains avatar Aug 26 '24 13:08 lucidrains

@lucidrains I see the benefits of the original formulation.

By increasing the number of experts from 200x200 to 500x500 and decreasing dim_key from 768 to 128, the curve is converging. image

The green curve is the base model, which has 24 layers and 0.9B params (limited by GPU memory); purple is the PEER MLP model, which has 16 layers and 1.2B params.

The base model converges faster at first, but lags behind PEER at a later stage, which seems intuitive.

junphine avatar Aug 27 '24 10:08 junphine

@junphine hey that's great! thank you for sharing this! 🚀

lucidrains avatar Aug 27 '24 13:08 lucidrains