recurrent-memory-transformer-pytorch

Question: first read memories


During the first run, mems == None, so the model doesn't attend to any "read" tokens, as per:

https://github.com/lucidrains/recurrent-memory-transformer-pytorch/blob/3be7d43604c6921a7dbdc68f88c7f3c534f82d2a/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py#L350-L355

Why not always attend to read_memory_emb, and replace that block with:

 read_mem_length = mem_length 
 read_mems = repeat(self.read_memory_emb, 'm d -> b m d', b = batch)
 if exists(read_memories): 
     read_mems += read_memories

pfeatherstone avatar Aug 30 '23 14:08 pfeatherstone

@pfeatherstone ohh, yea i could add it and mask it out

is that the only thing keeping it from being onnx-able?

lucidrains avatar Aug 31 '23 17:08 lucidrains

Pretty much. Basically, nothing can be optional with tracing and ONNX, so both mems and mask must be specified.
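For illustration, the export call would need to look roughly like this (a sketch only; the wrapper, shapes, and keyword names here are my assumptions, not the library's exact API):

import torch
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

model = RecurrentMemoryTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 2,
    num_memory_tokens = 16,
    seq_len = 1024
).eval()

# a wrapper pins down a fixed, fully-specified signature for tracing
class Wrapper(torch.nn.Module):
    def __init__(self, rmt):
        super().__init__()
        self.rmt = rmt

    def forward(self, x, mems, mask):
        # keyword names assumed from the snippet linked above
        return self.rmt(x, read_memories = mems, mask = mask)

batch, seq, mem, dim = 1, 1024, 16, 512

x    = torch.randint(0, 256, (batch, seq))
mems = torch.zeros(batch, mem, dim)               # cannot be None when tracing
mask = torch.ones(batch, seq, dtype = torch.bool) # same for the mask

torch.onnx.export(
    Wrapper(model),
    (x, mems, mask),
    'rmt.onnx',
    input_names = ['x', 'mems', 'mask'],
    opset_version = 17
)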

pfeatherstone avatar Aug 31 '23 17:08 pfeatherstone

cool, will get it done later today!

lucidrains avatar Aug 31 '23 17:08 lucidrains

@pfeatherstone decided to take the easy way out: https://github.com/lucidrains/recurrent-memory-transformer-pytorch/commit/90de2ac64c1ce2d2ef90f3b63dbdcecf8af2a024

let me know if this does it

lucidrains avatar Aug 31 '23 19:08 lucidrains

You still need to pass mems in the forward pass when tracing. So I think the fix would be to pass mems full of zeros and have the code handle that as if it were None.

pfeatherstone avatar Aug 31 '23 20:08 pfeatherstone

Then you'd probably need some control flow in a torch.jit.script function that handles the different cases, for example:
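Something along these lines (just a sketch; the helper name and shapes are illustrative):

import torch
from typing import Optional

@torch.jit.script
def prepare_read_memories(read_memory_emb: torch.Tensor,
                          read_memories: Optional[torch.Tensor],
                          batch: int) -> torch.Tensor:
    # scripted, so the None-check remains real control flow instead of
    # being baked into a single branch by the tracer
    read_mems = read_memory_emb.unsqueeze(0).expand(batch, -1, -1)
    if read_memories is not None:
        read_mems = read_mems + read_memories
    return read_mems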

pfeatherstone avatar Aug 31 '23 20:08 pfeatherstone

@pfeatherstone have you tried this manually?

lucidrains avatar Aug 31 '23 20:08 lucidrains

Yeah, I have. I could submit a PR, but I didn't know if passing zeros and letting the model attend to them was OK.

pfeatherstone avatar Aug 31 '23 20:08 pfeatherstone

@pfeatherstone it is harmless to attend to them, as the zeros just get added to the read memory positional embeddings

however, decided to add a keyword argument on forward that can mask out the read memories if need be
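roughly like this, to sketch the intended usage (the keyword name here is illustrative only):

import torch
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

model = RecurrentMemoryTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 2,
    num_memory_tokens = 16,
    seq_len = 1024
)

x     = torch.randint(0, 256, (1, 1024))
zeros = torch.zeros(1, 16, 512)   # stand-in for "no previous memories"

# attending to zeroed read memories is harmless on its own - they only add 0
# to the read memory positional embeddings. the keyword below (name is
# illustrative) additionally masks them out on the very first segment
logits, mems = model(x, read_memories = zeros, mask_out_read_memories = True)

# later segments pass the real memories and leave the masking off
logits, mems = model(x, read_memories = mems)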

lucidrains avatar Aug 31 '23 21:08 lucidrains

@pfeatherstone all good?

lucidrains avatar Sep 02 '23 14:09 lucidrains

Ah yes, I still need to try it. Can it wait till Tuesday?

pfeatherstone avatar Sep 02 '23 15:09 pfeatherstone

@pfeatherstone yea, just checking, take your time

lucidrains avatar Sep 02 '23 15:09 lucidrains