recurrent-memory-transformer-pytorch

Feature request: make JIT and ONNX export work

Open · pfeatherstone opened this issue 2 years ago · 4 comments

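```python
import torch
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

# lengths_to_padding_mask(seq_len, lengths) is the reporter's own helper (not shown);
# it turns per-sample lengths into a (batch, seq_len) boolean mask

net = RecurrentMemoryTransformer(
    seq_len=1024,
    num_tokens=256,
    num_memory_tokens=128,
    dim=512,
    depth=1,
    causal=True,
    heads=4,
    dim_head=128,
    use_flash_attn=True,
    rotary_pos_emb=True
).eval()

# trace with a single example input: tokens only, no memories, no mask
x = torch.randint(0, 256, (8, 1024))

jit = torch.jit.trace(net, (x,))

# fresh batch with random valid lengths in [100, seq_len)
x = torch.randint(0, 256, (8, 1024))
l = torch.randint(100, x.shape[1], size=(x.shape[0],))
m = lengths_to_padding_mask(x.shape[1], l)

# eager model: first segment without memories, second segment reading them back
l1, mems, _ = net(x, mask=m)
l2, mems, _ = net(x, mems, mask=m)

# desired: the traced module accepts the same calls...
l3, mems, _ = jit(x, mask=m)
l4, mems, _ = jit(x, mems, mask=m)

# ...and reproduces the eager outputs
torch.testing.assert_close(l1, l3)
torch.testing.assert_close(l2, l4)
```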

It would be great if the above worked.
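The snippet calls `lengths_to_padding_mask`, which the issue does not define. A minimal sketch of such a helper is below; the function is hypothetical, and it assumes the mask should be `True` for real tokens and `False` for padding. If the model's `mask` argument expects the opposite convention, invert the result with `~mask`.

```python
import torch


def lengths_to_padding_mask(seq_len: int, lengths: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (batch,) lengths -> (batch, seq_len) bool mask.

    Assumes True marks real tokens and False marks padding.
    """
    positions = torch.arange(seq_len, device=lengths.device)  # (seq_len,)
    # broadcast (1, seq_len) against (batch, 1) -> (batch, seq_len)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)


lengths = torch.tensor([3, 5, 1])
mask = lengths_to_padding_mask(5, lengths)
```

Each row of `mask` has exactly `lengths[i]` leading `True` values, so the mask is built without any Python loop and stays traceable.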

pfeatherstone, Aug 29 '23 13:08

@pfeatherstone ahh, yea, i can look into that

care to share what you are seeing on your dataset with this approach?

lucidrains, Aug 29 '23 15:08

I haven't yet managed to train my models even with plain transformers at larger context lengths (my weird TTS + STT system): the CTC loss isn't converging at all. So I haven't attempted a proper run with the RMT architecture in the STT model, but I'm setting it up with RMT while I debug the other one. I'll let you know if I have any success. I'm worried that training two transformers in tandem simply doesn't work, whether because of painfully slow convergence, too small a batch size, or something else; I don't know. I've been trying shifted tokens, scale_norm and other tricks to help convergence, with no luck so far. I'm tempted to try RWKV, since its authors claim very fast convergence. Either way, I'll need something like RMT in the end so that the STT side has a well-defined streaming architecture.

pfeatherstone, Aug 29 '23 15:08

oh got it, makes sense

lucidrains, Aug 29 '23 15:08

Gave this a go. It turns out that torch.jit.trace() doesn't accept None in example_inputs, so we cannot trace with mems set to a tensor and expect the traced module to work when mems is None, or vice versa. My workaround is to pass mems=torch.zeros(B, num_memory_tokens, dim) in the first pass, which means the model attends ONLY to self.read_memory_emb in that first pass. I don't know if that's allowed.
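The zero-memory workaround can be illustrated on a small stand-in model. This is a sketch under stated assumptions: `ToyMemoryModel` is hypothetical and is not the library's `RecurrentMemoryTransformer`. It only assumes that read memories are added to a learned `read_memory_emb` and prepended to the token sequence. Because the memory argument is always a tensor, `torch.jit.trace` records one graph that serves both the first segment (zeros) and later segments (real memories).

```python
import torch
from torch import nn


class ToyMemoryModel(nn.Module):
    """Hypothetical stand-in for an RMT-style model; NOT the library's class."""

    def __init__(self, num_memory_tokens: int = 4, dim: int = 8):
        super().__init__()
        self.num_memory_tokens = num_memory_tokens
        # learned embedding added to whatever memories are read in
        self.read_memory_emb = nn.Parameter(torch.randn(num_memory_tokens, dim))
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, mems: torch.Tensor):
        # mems is always a tensor; all zeros on the first segment, so the
        # model then reads only read_memory_emb
        read = mems + self.read_memory_emb          # (batch, num_mem, dim)
        h = self.proj(torch.cat([read, x], dim=1))  # (batch, num_mem + seq, dim)
        n = self.num_memory_tokens
        return h[:, n:], h[:, :n]                   # (outputs, new memories)


torch.manual_seed(0)
batch, seq_len, dim, num_mem = 2, 6, 8, 4
model = ToyMemoryModel(num_mem, dim).eval()

x = torch.randn(batch, seq_len, dim)
zero_mems = torch.zeros(batch, num_mem, dim)

# trace once with a tensor (never None) in the memory slot
traced = torch.jit.trace(model, (x, zero_mems))

with torch.no_grad():
    out1, mems1 = model(x, zero_mems)      # first segment, eager
    out2, mems2 = model(x, mems1)          # second segment, eager
    t_out1, t_mems1 = traced(x, zero_mems) # first segment, traced
    t_out2, t_mems2 = traced(x, t_mems1)   # second segment, traced

torch.testing.assert_close(out1, t_out1)
torch.testing.assert_close(out2, t_out2)
```

The toy has no separate `None` branch, so it cannot show whether zero memories are equivalent to `mems=None` in the real model; that depends on how the library handles the `None` case, and on whether the model was trained with the first segment reading `read_memory_emb` alone.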

pfeatherstone, Aug 30 '23 14:08